diff --git a/.fides/db_dataset.yml b/.fides/db_dataset.yml
index fc19fd6e19..131d224d97 100644
--- a/.fides/db_dataset.yml
+++ b/.fides/db_dataset.yml
@@ -1276,6 +1276,8 @@ dataset:
       - name: privacyrequest
         data_categories: []
         fields:
+          - name: access_result_urls
+            data_categories: [user]
           - name: cancel_reason
             data_categories: [system.operations]
           - name: canceled_at
@@ -1289,6 +1291,8 @@ dataset:
             data_categories: [system.operations]
           - name: external_id
             data_categories: [system.operations]
+          - name: filtered_final_upload
+            data_categories: [user]
           - name: finished_processing_at
             data_categories: [system.operations]
           - name: id
@@ -2109,3 +2113,43 @@ dataset:
             data_categories: [ system.operations ]
           - name: is_eligible
             data_categories: [ system.operations ]
+
+        data_categories: [system.operations]
+      - name: requesttask
+        fields:
+          - name: access_data
+            data_categories: [user]
+          - name: action_type
+            data_categories: [system]
+          - name: all_descendant_tasks
+            data_categories: [system]
+          - name: collection
+            data_categories: [system]
+          - name: collection_address
+            data_categories: [system]
+          - name: collection_name
+            data_categories: [system]
+          - name: consent_sent
+            data_categories: [system]
+          - name: created_at
+            data_categories: [system]
+          - name: data_for_erasures
+            data_categories: [user]
+          - name: dataset_name
+            data_categories: [system]
+          - name: downstream_tasks
+            data_categories: [system]
+          - name: id
+            data_categories: [system]
+          - name: privacy_request_id
+            data_categories: [system]
+          - name: rows_masked
+            data_categories: [system]
+          - name: status
+            data_categories: [system]
+          - name: traversal_details
+            data_categories: [system]
+          - name: updated_at
+            data_categories: [system]
+          - name: upstream_tasks
+            data_categories: [system]
\ No newline at end of file
diff --git a/.fides/fides.toml b/.fides/fides.toml
index f336635430..a473e3a7e4 100644
--- a/.fides/fides.toml
+++ b/.fides/fides.toml
@@ -50,6 +50,7 @@ task_retry_backoff = 1
 subject_identity_verification_required = false
 task_retry_count = 0
 task_retry_delay = 1
+use_dsr_3_0 = false
 
 [admin_ui]
 enabled = true
diff --git a/CHANGELOG.md b/CHANGELOG.md
index f2211add1c..5aabd58407 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,7 @@ The types of changes are:
 ## [Unreleased](https://github.com/ethyca/fides/compare/2.34.0...main)
 
 ### Added
+- Added DSR 3.0 Scheduling, which supports running DSRs in parallel with first-class request tasks [#4760](https://github.com/ethyca/fides/pull/4760)
 - Added carets to collapsible sections in the overlay modal [#4793](https://github.com/ethyca/fides/pull/4793)
 - Added erasure support for OpenWeb [#4735](https://github.com/ethyca/fides/pull/4735)
 - Added support for configuration of pre-approval webhooks [#4795](https://github.com/ethyca/fides/pull/4795)
diff --git a/pyproject.toml b/pyproject.toml
index 3cbd2f2463..ad8ac68f16 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,6 +51,7 @@ module = [
     "jose.*",
     "jwt.*",
     "multidimensional_urlencode.*",
+    "networkx.*",
     "okta.*",
     "pandas.*",
     "plotly.*",
@@ -65,7 +66,7 @@ module = [
     "twilio.*",
     "uvicorn.*",
     "validators.*",
-    "pygtrie.*"
+    "pygtrie.*",
 ]
 ignore_missing_imports = true
diff --git a/requirements.txt b/requirements.txt
index 186d030d1f..60c6f1bd7c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -29,7 +29,7 @@ multidimensional_urlencode==0.0.4
 nh3==0.2.15
 okta==2.7.0
 openpyxl==3.0.9
-networkx==3.1 # added to help with privacy preference data migration
+networkx==3.1
 packaging==23.0
 pandas==1.4.3
 paramiko==3.4.0
diff --git
a/src/fides/api/alembic/migrations/versions/55bedede956d_requesttask.py b/src/fides/api/alembic/migrations/versions/55bedede956d_requesttask.py new file mode 100644 index 0000000000..8407cb4c47 --- /dev/null +++ b/src/fides/api/alembic/migrations/versions/55bedede956d_requesttask.py @@ -0,0 +1,133 @@ +"""requesttask + +Revision ID: 55bedede956d +Revises: 6cfd59e7920a +Create Date: 2024-04-04 04:12:25.332952 + +""" +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "55bedede956d" +down_revision = "6cfd59e7920a" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "requesttask", + sa.Column("id", sa.String(length=255), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column("privacy_request_id", sa.String(), nullable=True), + sa.Column("collection_address", sa.String(), nullable=False), + sa.Column("dataset_name", sa.String(), nullable=False), + sa.Column("collection_name", sa.String(), nullable=False), + sa.Column("action_type", sa.String(), nullable=False), + sa.Column("status", sa.String(), nullable=False), + sa.Column( + "upstream_tasks", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.Column( + "downstream_tasks", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.Column( + "all_descendant_tasks", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "access_data", + sqlalchemy_utils.types.encrypted.encrypted_type.StringEncryptedType(), + nullable=True, + ), + sa.Column( + "data_for_erasures", + sqlalchemy_utils.types.encrypted.encrypted_type.StringEncryptedType(), + nullable=True, + ), + sa.Column("rows_masked", sa.Integer(), nullable=True), + sa.Column("consent_sent", sa.Boolean(), nullable=True), + sa.Column("collection", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column( + "traversal_details", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.ForeignKeyConstraint( + ["privacy_request_id"], ["privacyrequest.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_requesttask_action_type"), "requesttask", ["action_type"], unique=False + ) + op.create_index( + op.f("ix_requesttask_collection_address"), + "requesttask", + ["collection_address"], + unique=False, + ) + op.create_index( + op.f("ix_requesttask_collection_name"), + "requesttask", + ["collection_name"], + unique=False, + ) + op.create_index( + op.f("ix_requesttask_dataset_name"), + "requesttask", + ["dataset_name"], + unique=False, + ) + op.create_index(op.f("ix_requesttask_id"), "requesttask", ["id"], unique=False) + op.create_index( + op.f("ix_requesttask_privacy_request_id"), + "requesttask", + ["privacy_request_id"], + unique=False, + ) + op.create_index( + op.f("ix_requesttask_status"), "requesttask", ["status"], unique=False + ) + op.add_column( + "privacyrequest", + sa.Column( + "filtered_final_upload", + sqlalchemy_utils.types.encrypted.encrypted_type.StringEncryptedType(), + nullable=True, + ), + ) + op.add_column( + "privacyrequest", + sa.Column( + "access_result_urls", + sqlalchemy_utils.types.encrypted.encrypted_type.StringEncryptedType(), + nullable=True, + ), + ) + + +def downgrade(): + op.drop_column("privacyrequest", 
"access_result_urls") + op.drop_column("privacyrequest", "filtered_final_upload") + op.drop_index(op.f("ix_requesttask_status"), table_name="requesttask") + op.drop_index(op.f("ix_requesttask_privacy_request_id"), table_name="requesttask") + op.drop_index(op.f("ix_requesttask_id"), table_name="requesttask") + op.drop_index(op.f("ix_requesttask_dataset_name"), table_name="requesttask") + op.drop_index(op.f("ix_requesttask_collection_name"), table_name="requesttask") + op.drop_index(op.f("ix_requesttask_collection_address"), table_name="requesttask") + op.drop_index(op.f("ix_requesttask_action_type"), table_name="requesttask") + op.drop_table("requesttask") diff --git a/src/fides/api/api/v1/endpoints/privacy_request_endpoints.py b/src/fides/api/api/v1/endpoints/privacy_request_endpoints.py index d7dd17eadd..973e87db35 100644 --- a/src/fides/api/api/v1/endpoints/privacy_request_endpoints.py +++ b/src/fides/api/api/v1/endpoints/privacy_request_endpoints.py @@ -46,7 +46,7 @@ ValidationError, ) from fides.api.graph.config import CollectionAddress -from fides.api.graph.graph import DatasetGraph, Node +from fides.api.graph.graph import DatasetGraph from fides.api.graph.traversal import Traversal from fides.api.models.audit_log import AuditLog, AuditLogAction from fides.api.models.client import ClientDetail @@ -59,11 +59,13 @@ CheckpointActionRequired, ConsentRequest, ExecutionLog, + ExecutionLogStatus, PrivacyRequest, PrivacyRequestNotifications, PrivacyRequestStatus, ProvidedIdentity, ProvidedIdentityType, + RequestTask, ) from fides.api.oauth.utils import verify_callback_oauth, verify_oauth_client from fides.api.schemas.dataset import CollectionAddressResponse, DryRunDatasetResponse @@ -84,9 +86,9 @@ PrivacyRequestCreate, PrivacyRequestNotificationInfo, PrivacyRequestResponse, + PrivacyRequestTaskSchema, PrivacyRequestVerboseResponse, ReviewPrivacyRequestIds, - RowCountRequest, VerificationCode, ) from fides.api.schemas.redis_cache import Identity @@ -104,7 +106,7 @@ cache_data, ) from fides.api.task.filter_results import filter_data_categories -from fides.api.task.graph_task import EMPTY_REQUEST, collect_queries +from fides.api.task.graph_task import EMPTY_REQUEST, EMPTY_REQUEST_TASK, collect_queries from fides.api.task.task_resources import TaskResources from fides.api.tasks import MESSAGING_QUEUE_NAME from fides.api.util.api_router import APIRouter @@ -129,11 +131,10 @@ PRIVACY_REQUEST_AUTHENTICATED, PRIVACY_REQUEST_BULK_RETRY, PRIVACY_REQUEST_DENY, - PRIVACY_REQUEST_MANUAL_ERASURE, - PRIVACY_REQUEST_MANUAL_INPUT, PRIVACY_REQUEST_MANUAL_WEBHOOK_ACCESS_INPUT, PRIVACY_REQUEST_MANUAL_WEBHOOK_ERASURE_INPUT, PRIVACY_REQUEST_NOTIFICATIONS, + PRIVACY_REQUEST_REQUEUE, PRIVACY_REQUEST_RESUME, PRIVACY_REQUEST_RESUME_FROM_REQUIRES_INPUT, PRIVACY_REQUEST_RETRY, @@ -142,6 +143,7 @@ PRIVACY_REQUESTS, REQUEST_PREVIEW, REQUEST_STATUS_LOGS, + REQUEST_TASKS, V1_URL_PREFIX, ) from fides.config import CONFIG @@ -487,18 +489,8 @@ def attach_resume_instructions(privacy_request: PrivacyRequest) -> None: action_required_details: Optional[CheckpointActionRequired] = None if privacy_request.status == PrivacyRequestStatus.paused: - action_required_details = privacy_request.get_paused_collection_details() - - if action_required_details: - # Graph is paused on a specific collection - resume_endpoint = ( - PRIVACY_REQUEST_MANUAL_ERASURE - if action_required_details.step == CurrentStep.erasure - else PRIVACY_REQUEST_MANUAL_INPUT - ) - else: - # Graph is paused on a pre-processing webhook - resume_endpoint = 
PRIVACY_REQUEST_RESUME + # Graph is paused on a pre-processing webhook + resume_endpoint = PRIVACY_REQUEST_RESUME elif privacy_request.status == PrivacyRequestStatus.error: action_required_details = privacy_request.get_failed_checkpoint_details() @@ -797,9 +789,12 @@ def get_request_preview_queries( k: "something" for k in dataset_graph.identity_keys.values() } traversal: Traversal = Traversal(dataset_graph, identity_seed) + queries: Dict[CollectionAddress, str] = collect_queries( traversal, - TaskResources(EMPTY_REQUEST, Policy(), connection_configs, db), + TaskResources( + EMPTY_REQUEST, Policy(), connection_configs, EMPTY_REQUEST_TASK, db + ), ) return [ DryRunDatasetResponse( @@ -871,7 +866,7 @@ def validate_manual_input( """ for row in manual_rows: for field_name in row: - if not dataset_graph.nodes[collection].contains_field( + if not dataset_graph.nodes[collection].collection.contains_field( lambda f: f.name == field_name # pylint: disable=W0640 ): raise HTTPException( @@ -880,151 +875,6 @@ def validate_manual_input( ) -def resume_privacy_request_with_manual_input( - privacy_request_id: str, - db: Session, - expected_paused_step: CurrentStep, - manual_rows: List[Row] = [], - manual_count: Optional[int] = None, -) -> PrivacyRequest: - """Resume privacy request after validating and caching manual data for an access or an erasure request. - - This assumes the privacy request is being resumed from a specific collection in the graph. - """ - privacy_request: PrivacyRequest = get_privacy_request_or_error( - db, privacy_request_id - ) - if privacy_request.status != PrivacyRequestStatus.paused: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=f"Invalid resume request: privacy request '{privacy_request.id}' " # type: ignore - f"status = {privacy_request.status.value}. Privacy request is not paused.", - ) - - paused_details: Optional[ - CheckpointActionRequired - ] = privacy_request.get_paused_collection_details() - if not paused_details: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=f"Cannot resume privacy request '{privacy_request.id}'; no paused details.", - ) - - paused_step: CurrentStep = paused_details.step - paused_collection: Optional[CollectionAddress] = paused_details.collection - - if paused_step != expected_paused_step: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=f"Collection '{paused_collection}' is paused at the {paused_step.value} step. Pass in manual data instead to " - f"'{PRIVACY_REQUEST_MANUAL_ERASURE if paused_step == CurrentStep.erasure else PRIVACY_REQUEST_MANUAL_INPUT}' to resume.", - ) - - datasets = DatasetConfig.all(db=db) - dataset_graphs = [dataset_config.get_graph() for dataset_config in datasets] - dataset_graph = DatasetGraph(*dataset_graphs) - - if not paused_collection: - raise HTTPException( - status_code=HTTP_422_UNPROCESSABLE_ENTITY, - detail="Cannot save manual data on paused collection. No paused collection saved'.", - ) - - node: Optional[Node] = dataset_graph.nodes.get(paused_collection) - if not node: - raise HTTPException( - status_code=HTTP_422_UNPROCESSABLE_ENTITY, - detail=f"Cannot save manual data. 
No collection in graph with name: '{paused_collection.value}'.", - ) - - if paused_step == CurrentStep.access: - validate_manual_input(manual_rows, paused_collection, dataset_graph) - logger.info( - "Caching manual access input for privacy request '{}', collection: '{}'", - privacy_request_id, - paused_collection, - ) - privacy_request.cache_manual_access_input(paused_collection, manual_rows) - - elif paused_step == CurrentStep.erasure: - logger.info( - "Caching manually erased row count for privacy request '{}', collection: '{}'", - privacy_request_id, - paused_collection, - ) - privacy_request.cache_manual_erasure_count(paused_collection, manual_count) # type: ignore - - logger.info( - "Resuming privacy request '{}', {} step, from collection '{}'", - privacy_request_id, - paused_step.value, - paused_collection.value, - ) - - privacy_request.status = PrivacyRequestStatus.in_processing - privacy_request.save(db=db) - - queue_privacy_request( - privacy_request_id=privacy_request.id, - from_step=paused_step.value, - ) - - return privacy_request - - -@router.post( - PRIVACY_REQUEST_MANUAL_INPUT, - status_code=HTTP_200_OK, - response_model=PrivacyRequestResponse, - dependencies=[ - Security(verify_oauth_client, scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) - ], -) -def resume_with_manual_input( - privacy_request_id: str, - *, - db: Session = Depends(deps.get_db), - manual_rows: List[Row], -) -> PrivacyRequestResponse: - """Resume a privacy request by passing in manual input for the paused collection. - - If there's no manual data to submit, pass in an empty list to resume the privacy request. - """ - return resume_privacy_request_with_manual_input( - privacy_request_id=privacy_request_id, - db=db, - expected_paused_step=CurrentStep.access, - manual_rows=manual_rows, - ) # type: ignore[return-value] - - -@router.post( - PRIVACY_REQUEST_MANUAL_ERASURE, - status_code=HTTP_200_OK, - response_model=PrivacyRequestResponse, - dependencies=[ - Security(verify_oauth_client, scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) - ], -) -def resume_with_erasure_confirmation( - privacy_request_id: str, - *, - db: Session = Depends(deps.get_db), - cache: FidesopsRedis = Depends(deps.get_cache), - manual_count: RowCountRequest, -) -> PrivacyRequestResponse: - """Resume the erasure portion of privacy request by passing in the number of rows that were manually masked. - - If no rows were masked, pass in a 0 to resume the privacy request. 
- """ - return resume_privacy_request_with_manual_input( - privacy_request_id=privacy_request_id, - db=db, - expected_paused_step=CurrentStep.erasure, - manual_count=manual_count.row_count, - ) # type: ignore[return-value] - - @router.post( PRIVACY_REQUEST_BULK_RETRY, status_code=HTTP_200_OK, @@ -1070,7 +920,6 @@ def bulk_restart_privacy_request_from_failure( _process_privacy_request_restart( privacy_request, failed_details.step if failed_details else None, - failed_details.collection if failed_details else None, db, ) ) @@ -1109,7 +958,6 @@ def restart_privacy_request_from_failure( return _process_privacy_request_restart( privacy_request, failed_details.step if failed_details else None, - failed_details.collection if failed_details else None, db, ) @@ -1894,17 +1742,15 @@ def _create_or_update_custom_fields( def _process_privacy_request_restart( privacy_request: PrivacyRequest, failed_step: Optional[CurrentStep], - failed_collection: Optional[CollectionAddress], db: Session, ) -> PrivacyRequestResponse: - """If failed_step and failed_collection are provided, restart the DSR within that step. Otherwise, + """If failed_step is provided, restart the DSR within that step. Otherwise, restart the privacy request from the beginning.""" - if failed_step and failed_collection: + if failed_step: logger.info( - "Restarting failed privacy request '{}' from '{} step, 'collection '{}'", + "Restarting failed privacy request '{}' from '{}'", privacy_request.id, failed_step, - failed_collection, ) else: logger.info( @@ -1920,3 +1766,88 @@ def _process_privacy_request_restart( ) return privacy_request # type: ignore[return-value] + + +@router.get( + REQUEST_TASKS, + dependencies=[Security(verify_oauth_client, scopes=[PRIVACY_REQUEST_READ])], + response_model=List[PrivacyRequestTaskSchema], +) +def get_individual_privacy_request_tasks( + privacy_request_id: str, + *, + db: Session = Depends(deps.get_db), +) -> List[RequestTask]: + """Returns individual Privacy Request Tasks created by DSR 3.0 scheduler + in order by creation and collection address""" + pr: PrivacyRequest = get_privacy_request_or_error(db, privacy_request_id) + + logger.info(f"Getting Request Tasks for '{privacy_request_id}'") + + return pr.request_tasks.order_by( + RequestTask.created_at.asc(), RequestTask.collection_address.asc() + ).all() + + +@router.post( + PRIVACY_REQUEST_REQUEUE, + dependencies=[ + Security(verify_oauth_client, scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) + ], + response_model=PrivacyRequestResponse, +) +def requeue_privacy_request( + privacy_request_id: str, + *, + db: Session = Depends(deps.get_db), +) -> PrivacyRequestResponse: + """ + Endpoint for manually re-queuing a stuck Privacy Request from selected states - use with caution. + + Don't use this unless the Privacy Request is stuck. + """ + pr: PrivacyRequest = get_privacy_request_or_error(db, privacy_request_id) + + if pr.status not in [ + PrivacyRequestStatus.approved, + PrivacyRequestStatus.in_processing, + ]: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"Request failed. Cannot re-queue privacy request {pr.id} with status {pr.status.value}", + ) + + # Both DSR 2.0 and 3.0 cache checkpoint details + checkpoint_details: Optional[ + CheckpointActionRequired + ] = pr.get_failed_checkpoint_details() + resume_step = checkpoint_details.step if checkpoint_details else None + + # DSR 3.0 additionally stores Request Tasks in the application db that can be used to infer + # a resume checkpoint in the event the cache has expired. 
+ if not resume_step and pr.request_tasks.count(): + if pr.consent_tasks.count(): + resume_step = CurrentStep.consent + elif pr.erasure_tasks.count(): + # Checking if access terminator task was completed, because erasure tasks are created + # at the same time as the access tasks + terminator_access_task = pr.get_terminate_task_by_action(ActionType.access) + resume_step = ( + CurrentStep.erasure + if terminator_access_task.status == ExecutionLogStatus.complete + else CurrentStep.access + ) + elif pr.access_tasks.count(): + resume_step = CurrentStep.access + + logger.info( + "Manually re-queuing Privacy Request {} from step {}", + pr, + resume_step.value if resume_step else None, + ) + + return _process_privacy_request_restart( + pr, + resume_step, + db, + ) diff --git a/src/fides/api/common_exceptions.py b/src/fides/api/common_exceptions.py index d1d26b311d..62026a9b40 100644 --- a/src/fides/api/common_exceptions.py +++ b/src/fides/api/common_exceptions.py @@ -91,6 +91,10 @@ class PolicyNotFoundException(Exception): """Policy could not be found""" +class ResumeTaskException(Exception): + """Issue restoring data from collection to resume Privacy Request Processing""" + + class ConnectorNotFoundException(Exception): """Connector could not be found""" @@ -139,6 +143,14 @@ class NotSupportedForCollection(BaseException): """The given action is not supported for this type of collection""" +class PrivacyRequestExit(BaseException): + """Privacy request exiting processing waiting on subtasks to complete""" + + +class PrivacyRequestCanceled(BaseException): + """Privacy Request has been Canceled""" + + class PrivacyRequestPaused(BaseException): """Halt Instruction Received on Privacy Request""" @@ -147,6 +159,14 @@ class PrivacyRequestNotFound(BaseException): """Privacy Request Not Found""" +class RequestTaskNotFound(BaseException): + """Privacy Request Task Not Found""" + + +class UpstreamTasksNotReady(BaseException): + """Privacy Request Task awaiting upstream tasks""" + + class NoCachedManualWebhookEntry(BaseException): """No manual data exists for this webhook on the given privacy request.""" diff --git a/src/fides/api/graph/analytics_events.py b/src/fides/api/graph/analytics_events.py deleted file mode 100644 index 6b50ff4adb..0000000000 --- a/src/fides/api/graph/analytics_events.py +++ /dev/null @@ -1,106 +0,0 @@ -from datetime import datetime, timezone -from typing import TYPE_CHECKING, Dict, List, Optional - -from fideslog.sdk.python.event import AnalyticsEvent - -from fides.api.analytics import in_docker_container, send_analytics_event -from fides.api.graph.config import CollectionAddress -from fides.api.graph.graph_differences import ( - GraphDiffSummary, - GraphRepr, - find_graph_differences_summary, - format_graph_for_caching, -) -from fides.api.models.privacy_request import PrivacyRequest -from fides.api.schemas.policy import ActionType -from fides.api.task.task_resources import TaskResources -from fides.api.util.collection_util import Row -from fides.config import CONFIG - -if TYPE_CHECKING: - from fides.api.task.graph_task import GraphTask - - -async def fideslog_graph_failure(event: Optional[AnalyticsEvent]) -> None: - """Send an Analytics Event if privacy request execution has failed""" - if CONFIG.user.analytics_opt_out or not event: - return - - await send_analytics_event(event) - - -async def fideslog_graph_rerun(event: Optional[AnalyticsEvent]) -> None: - """Send an Analytics Event if a privacy request has been reprocessed, comparing its graph to the previous graph""" - if 
CONFIG.user.analytics_opt_out or not event: - return - - await send_analytics_event(event) - - -def prepare_rerun_graph_analytics_event( - privacy_request: PrivacyRequest, - env: Dict[CollectionAddress, "GraphTask"], - end_nodes: List[CollectionAddress], - resources: TaskResources, - step: ActionType, -) -> Optional[AnalyticsEvent]: - """Prepares an AnalyticsEvent to send to Fideslog with stats on how an access graph - has changed from the previous run if applicable. - - Even for erasure requests, we still compare the "access graphs", because that reflects - what data has changed and the relationships between them. - The erasure graph is really just a list that runs each node with data from the access graphs. - """ - previous_graph: Optional[GraphRepr] = privacy_request.get_cached_access_graph() - current_graph: GraphRepr = format_graph_for_caching(env, end_nodes) - - previous_access_results: Dict[ - str, Optional[List[Row]] - ] = resources.get_all_cached_objects() - - previous_erasure_results: Dict[str, int] = {} - if step == ActionType.erasure: - # Don't bother looking this up if we are running this just for the access portion - previous_erasure_results = resources.get_all_cached_erasures() - - graph_diff_summary: Optional[GraphDiffSummary] = find_graph_differences_summary( - previous_graph, current_graph, previous_access_results, previous_erasure_results - ) - - if not graph_diff_summary: - return None - - data = graph_diff_summary.dict() - data["privacy_request"] = privacy_request.id - - return AnalyticsEvent( - docker=in_docker_container(), - event="rerun_access_graph" - if step == ActionType.access - else "rerun_erasure_graph", - event_created_at=datetime.now(tz=timezone.utc), - local_host=None, - endpoint=None, - status_code=None, - error=None, - extra_data=data, - ) - - -def failed_graph_analytics_event( - privacy_request: PrivacyRequest, exc: Optional[BaseException] -) -> Optional[AnalyticsEvent]: - """Prepares an AnalyticsEvent to send to Fideslog if privacy request execution has failed.""" - - data = {"privacy_request": privacy_request.id} - - return AnalyticsEvent( - docker=in_docker_container(), - event="privacy_request_execution_failure", - event_created_at=datetime.now(tz=timezone.utc), - local_host=None, - endpoint=None, - status_code=500, - error=exc.__class__.__name__ if exc else None, - extra_data=data, - ) diff --git a/src/fides/api/graph/config.py b/src/fides/api/graph/config.py index 06f20554f0..fdb21d9186 100644 --- a/src/fides/api/graph/config.py +++ b/src/fides/api/graph/config.py @@ -80,7 +80,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple +from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union from fideslang.validation import FidesKey from pydantic import BaseModel, validator @@ -215,6 +215,22 @@ def collection_address(self) -> CollectionAddress: """Return the collection prefix of this field address.""" return CollectionAddress(self.dataset, self.collection) + @staticmethod + def from_string(field_address_string: str) -> FieldAddress: + """Creates a Field Address from a string - especially useful for instantiating + Fields in Collections that are built from data in RequestTask.collection""" + try: + split_string = field_address_string.split(":") + dataset = split_string[0] + collection = split_string[1] + fields = split_string[2] + split_fields = fields.split(".") + return FieldAddress(dataset, 
collection, *split_fields)
+        except Exception:
+            raise FidesopsException(
+                f"'{field_address_string}' is not a valid field address"
+            )
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, FieldAddress):
             return False
@@ -487,10 +503,90 @@ def field_paths_by_category(self) -> Dict[FidesKey, List[FieldPath]]:
             categories[category].append(field_path)
         return categories
 
+    def contains_field(self, func: Callable[[Field], bool]) -> bool:
+        """True if any field in this collection matches the condition of the callable.
+
+        Currently used to assert at least one field in the collection contains a primary
+        key before erasing.
+        """
+        return any(self.recursively_collect_matches(func))
+
+    @classmethod
+    def parse_from_request_task(cls, data: Dict) -> Collection:
+        """
+        Takes raw collection data saved on RequestTask.collection and converts it back into a Collection.
+
+        See Config > json_encoders for some of the fields that needed special handling for
+        serialization before database storage.
+        """
+        data = data.copy()
+
+        def build_field(serialized_field: dict) -> Field:
+            """Convert a serialized field on RequestTask.collection.fields into a ScalarField
+            or ObjectField"""
+            converted_references: List[
+                Tuple[FieldAddress, Optional[EdgeDirection]]
+            ] = []
+            for reference in serialized_field.pop("references", []):
+                field_address: str = reference[0]
+                edge_direction: Optional[str] = reference[1]
+                converted_references.append(
+                    (FieldAddress.from_string(field_address), edge_direction)  # type: ignore
+                )
+
+            data_type_converter: DataTypeConverter = get_data_type_converter(
+                serialized_field.pop("data_type_converter")
+            )
+
+            # We can't convert the fields to the abstract class Field - they need to be proper
+            # ScalarFields or ObjectFields
+            converted: Union[ObjectField, ScalarField]
+            if serialized_field.get("fields"):
+                # Recursively build nested fields under the ObjectField
+                serialized_field["fields"] = {
+                    field_name: build_field(fld)
+                    for field_name, fld in serialized_field["fields"].items()
+                }
+                converted = ObjectField.parse_obj(serialized_field)
+                converted.references = converted_references
+                converted.data_type_converter = data_type_converter
+                return converted
+
+            converted = ScalarField.parse_obj(serialized_field)
+            converted.references = converted_references
+            converted.data_type_converter = data_type_converter
+            return converted
+
+        converted_fields = []
+        for field in data.pop("fields"):
+            converted_fields.append(build_field(field))
+
+        data["fields"] = converted_fields
+        data["after"] = {
+            CollectionAddress.from_string(addr_string)
+            for addr_string in data.get("after", [])
+        }
+        data["erase_after"] = {
+            CollectionAddress.from_string(addr_string)
+            for addr_string in data.get("erase_after", [])
+        }
+
+        return Collection.parse_obj(data)
+
     class Config:
         """for pydantic incorporation of custom non-pydantic types"""
 
         arbitrary_types_allowed = True
+        # This supports running Collection.json() to serialize less standard
+        # types so it can be saved to the database under RequestTask.collection
+        json_encoders = {
+            Set: lambda val: list(  # pylint: disable=unhashable-member,unnecessary-lambda
+                val
+            ),
+            DataTypeConverter: lambda dtc: dtc.name if dtc.name else None,
+            FieldAddress: lambda fa: fa.value,
+            CollectionAddress: lambda ca: ca.value,
+        }
 
 
 class GraphDataset(BaseModel):
diff --git a/src/fides/api/graph/data_type.py b/src/fides/api/graph/data_type.py
index bd0aee1d8d..11d632149f 100644
--- a/src/fides/api/graph/data_type.py
+++ b/src/fides/api/graph/data_type.py
@@ -198,10 +198,11 @@ def is_valid_data_type(type_name: str) -> bool:
 
 def get_data_type_converter(type_name: Optional[str]) -> DataTypeConverter:
-    """Return the matching type converter. If an empty string or None is passed in
+    """Return the matching type converter. If an empty string, None, or the string "None" is passed in
     will return the No-op converter, so the converter will never be set to 'None'.
-    On an illegal key will raise a KeyError."""
-    if not type_name:
+
+    Only an illegal key will raise a KeyError."""
+    if not type_name or type_name == "None":
         return DataType.no_op.value
     return DataType[type_name].value
diff --git a/src/fides/api/graph/execution.py b/src/fides/api/graph/execution.py
new file mode 100644
index 0000000000..5a7158808a
--- /dev/null
+++ b/src/fides/api/graph/execution.py
@@ -0,0 +1,129 @@
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple
+
+from fideslang.validation import FidesKey
+
+from fides.api.graph.config import (
+    Collection,
+    CollectionAddress,
+    Field,
+    FieldAddress,
+    FieldPath,
+)
+from fides.api.graph.graph import Edge
+from fides.api.models.privacy_request import RequestTask, TraversalDetails
+from fides.api.util.collection_util import partition
+from fides.api.util.logger_context_utils import Contextualizable, LoggerContextKeys
+
+COLLECTION_FIELD_PATH_MAP = Dict[CollectionAddress, List[Tuple[FieldPath, FieldPath]]]
+
+
+class ExecutionNode(Contextualizable):  # pylint: disable=too-many-instance-attributes
+    """Node for *executing* a task. This node only has knowledge of itself and its incoming and outgoing edges.
+
+    After we build the graph, we save details to RequestTasks in the database that are hydrated here to execute an individual
+    node without rebuilding the graph with traversal.traverse
+    """
+
+    def __init__(self, request_task: RequestTask):
+        assert request_task.collection  # For mypy
+        self.collection: Collection = Collection.parse_from_request_task(
+            request_task.collection
+        )
+        self.address: CollectionAddress = CollectionAddress.from_string(
+            request_task.collection_address
+        )
+        traversal_details = TraversalDetails.parse_obj(
+            request_task.traversal_details or {}
+        )
+
+        self.incoming_edges: Set[Edge] = {
+            Edge(FieldAddress.from_string(edge[0]), FieldAddress.from_string(edge[1]))
+            for edge in traversal_details.incoming_edges
+        }
+        self.outgoing_edges: Set[Edge] = {
+            Edge(FieldAddress.from_string(edge[0]), FieldAddress.from_string(edge[1]))
+            for edge in traversal_details.outgoing_edges
+        }
+        self.connection_key: FidesKey = FidesKey(
+            traversal_details.dataset_connection_key
+        )
+
+        self.incoming_edges_by_collection: Dict[
+            CollectionAddress, List[Edge]
+        ] = partition(self.incoming_edges, lambda e: e.f1.collection_address())
+
+        # Input data should be passed in this order when accessing data on this node
+        self.input_keys: List[CollectionAddress] = [
+            CollectionAddress.from_string(input_key)
+            for input_key in traversal_details.input_keys
+        ]
+        self.grouped_fields = self.collection.grouped_inputs
+
+    @property
+    def query_field_paths(self) -> Set[FieldPath]:
+        """
+        All of the possible field paths that we can query for possible filter values.
+        These are field paths that are the ends of incoming edges.
+ """ + return {edge.f2.field_path for edge in self.incoming_edges} + + @property + def dependent_identity_fields(self) -> bool: + """If the current collection needs inputs from other collections, in addition to its seed data.""" + for field in self.grouped_fields: + if self.collection.field(FieldPath(field)).identity: # type: ignore + return True + return False + + def get_log_context(self) -> Dict[LoggerContextKeys, Any]: + return {LoggerContextKeys.collection: self.collection.name} + + def build_incoming_field_path_maps( + self, group_dependent_fields: bool = False + ) -> Tuple[COLLECTION_FIELD_PATH_MAP, COLLECTION_FIELD_PATH_MAP]: + """ + For each collection connected to the current collection, return a list of tuples + mapping the foreign field to the local field. This is used to process data from incoming collections + into the current collection. + + :param group_dependent_fields: Whether we should split the incoming fields into two groups: one whose + fields are completely independent of one another, and the other whose incoming data needs to stay linked together. + If False, all fields are returned in the first tuple, and the second tuple just maps collections to an empty list. + + """ + + def field_map(keep: Callable) -> COLLECTION_FIELD_PATH_MAP: + return { + col_addr: [ + (edge.f1.field_path, edge.f2.field_path) + for edge in edge_list + if keep(edge.f2.field_path.string_path) + ] + for col_addr, edge_list in self.incoming_edges_by_collection.items() + } + + if group_dependent_fields: + return field_map( + lambda string_path: string_path not in self.grouped_fields + ), field_map(lambda string_path: string_path in self.grouped_fields) + + return field_map(lambda string_path: True), field_map(lambda string_path: False) + + def typed_filtered_values(self, input_data: Dict[str, List[Any]]) -> Dict[str, Any]: + """ + Return a filtered list of key/value sets of data items that are both in + the list of incoming edge fields, and contain data in the input data set. + + The values are cast based on field types, if those types are specified. 
+ """ + out = {} + for key, values in input_data.items(): + path: FieldPath = FieldPath.parse(key) + field: Optional[Field] = self.collection.field(path) + + if field and path in self.query_field_paths and isinstance(values, list): + cast_values = [field.cast(v) for v in values] + filtered = list(filter(lambda x: x is not None, cast_values)) + if filtered: + out[key] = filtered + return out diff --git a/src/fides/api/graph/graph.py b/src/fides/api/graph/graph.py index c53eba9371..393a09eafe 100644 --- a/src/fides/api/graph/graph.py +++ b/src/fides/api/graph/graph.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections import defaultdict -from typing import Callable, Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple from fideslang.validation import FidesKey from loguru import logger @@ -11,7 +11,6 @@ Collection, CollectionAddress, EdgeDirection, - Field, FieldAddress, FieldPath, GraphDataset, @@ -47,14 +46,6 @@ def __eq__(self, other: object) -> bool: def __hash__(self) -> int: return hash(self.address) - def contains_field(self, func: Callable[[Field], bool]) -> bool: - """True if any field in this collection matches the condition of the callable - - Currently used to assert at least one field in the collection contains a primary - key before erasing - """ - return any(self.collection.recursively_collect_matches(func)) - class Edge: """A graph link uniquely defined by a pair of keys and a direction from f1->f2. diff --git a/src/fides/api/graph/graph_differences.py b/src/fides/api/graph/graph_differences.py deleted file mode 100644 index 9a0899a28f..0000000000 --- a/src/fides/api/graph/graph_differences.py +++ /dev/null @@ -1,208 +0,0 @@ -from typing import TYPE_CHECKING, Dict, List, Optional, Set - -from fides.api.graph.config import ( - ROOT_COLLECTION_ADDRESS, - TERMINATOR_ADDRESS, - CollectionAddress, -) -from fides.api.schemas.base_class import FidesSchema -from fides.api.util.collection_util import Row - -if TYPE_CHECKING: - from fides.api.task.graph_task import GraphTask - -GraphRepr = Dict[str, Dict[str, List[str]]] - - -def format_graph_for_caching( - env: Dict[CollectionAddress, "GraphTask"], end_nodes: List[CollectionAddress] -) -> GraphRepr: - """ - Builds a representation of the current graph built for a privacy request (that includes its edges) - for caching in Redis. - - Requires the results of traversal.traverse(): - - the modified `env` - - and the outputted end_nodes, which are the final nodes without children - - Maps collections to their upstream dependencies and associated edges. The root is stored as having - no upstream collections and the terminator collection has no incoming edges. - - Example: - { - : { - : [edge between upstream and current], - : [edge between another upstream and current] - }, - : {}, - : { - : [], - : [] - } - } - """ - graph_repr: GraphRepr = { - collection.value: { - upstream_collection_address.value: [str(edge) for edge in edge_list] - for upstream_collection_address, edge_list in g_task.incoming_edges_by_collection.items() - } - for collection, g_task in env.items() - } - graph_repr[ROOT_COLLECTION_ADDRESS.value] = {} - graph_repr[TERMINATOR_ADDRESS.value] = { - end_node.value: [] for end_node in end_nodes - } - - return graph_repr - - -class GraphDiff(FidesSchema): - """A more detailed description about how two graphs differ. 
Do not send these details to FidesLog.""" - - previous_collections: List[str] = [] - current_collections: List[str] = [] - added_collections: List[str] = [] - removed_collections: List[str] = [] - added_edges: List[str] = [] - removed_edges: List[str] = [] - already_processed_access_collections: List[str] = [] - already_processed_erasure_collections: List[str] = [] - skipped_added_edges: List[str] = [] - - -class GraphDiffSummary(FidesSchema): - """A summary about how two graphs have changed. This can be sent to FidesLog.""" - - prev_collection_count: int = 0 - curr_collection_count: int = 0 - added_collection_count: int = 0 - removed_collection_count: int = 0 - added_edge_count: int = 0 - removed_edge_count: int = 0 - already_processed_access_collection_count: int = 0 - already_processed_erasure_collection_count: int = 0 - skipped_added_edge_count: int = 0 - - -artificial_collections: Set[str] = { - ROOT_COLLECTION_ADDRESS.value, - TERMINATOR_ADDRESS.value, -} - - -def get_skipped_added_edges( - already_processed_access_collections: List[str], - current_graph: GraphRepr, - added_edges: List[str], -) -> List[str]: - """ - Gets newly added edges *directly* upstream of an already-processed collection. - - Already-processed collections have their immediate upstream edges removed from the graph - when we reprocess an access portion of the request. We don't re-query collections that were already run: - we use saved incoming results from last time. - """ - added_upstream_edges: List[str] = [] - - for collection in already_processed_access_collections: - for _, upstream_edges in current_graph[collection].items(): - for edge in upstream_edges: - if edge in added_edges: - added_upstream_edges.append(edge) - return added_upstream_edges - - -def _find_graph_differences( - previous_graph: Optional[GraphRepr], - current_graph: GraphRepr, - previous_results: Dict[str, Optional[List[Row]]], - previous_erasure_results: Dict[str, int], -) -> Optional[GraphDiff]: - """ - Determine how/if a graph has changed from the previous run when a privacy request is reprocessed. - - Takes in the previous graph, the current graph, and any collections that already ran the first time (previous_results). - Where applicable, we also take in the erasure collections that have already run. The current design doesn't run - the access request on a collection or the erasure portion of the collection more than once. 
- """ - if not previous_graph: - return None - - def all_edges(graph: GraphRepr) -> Set[str]: - edge_list: List[str] = [] - for _, dependent_collections in graph.items(): - for _, edges in dependent_collections.items(): - if edges: - edge_list.extend(edges) - return set(edge_list) - - current_collections: Set[str] = ( - set(list(current_graph.keys())) - artificial_collections - ) - current_edges: Set[str] = all_edges(current_graph) - previous_collections: Set[str] = ( - set(list(previous_graph.keys())) - artificial_collections - ) - previous_edges: Set[str] = all_edges(previous_graph) - - added_collections: List[str] = list(current_collections - previous_collections) - added_edges: List[str] = list(current_edges - previous_edges) - removed_collections: List[str] = list(previous_collections - current_collections) - removed_edges: List[str] = list(previous_edges - current_edges) - - already_processed_access_collections = list(previous_results.keys()) - skipped_added_edges: List[str] = get_skipped_added_edges( - already_processed_access_collections, current_graph, added_edges - ) - - already_processed_erasure_collections = list(previous_erasure_results.keys()) - - return GraphDiff( - previous_collections=list(sorted(previous_collections)), - current_collections=list(sorted(current_collections)), - added_collections=sorted(added_collections), - removed_collections=sorted(removed_collections), - added_edges=sorted(added_edges), - removed_edges=sorted(removed_edges), - already_processed_access_collections=sorted( - already_processed_access_collections - ), - already_processed_erasure_collections=sorted( - already_processed_erasure_collections - ), - skipped_added_edges=sorted(skipped_added_edges), - ) - - -def find_graph_differences_summary( - previous_graph: Optional[GraphRepr], - current_graph: GraphRepr, - previous_results: Dict[str, Optional[List[Row]]], - previous_erasure_results: Dict[str, int], -) -> Optional[GraphDiffSummary]: - """ - Summarizes the differences between the current graph and previous graph - with a series of counts. 
- """ - graph_diff: Optional[GraphDiff] = _find_graph_differences( - previous_graph, current_graph, previous_results, previous_erasure_results - ) - - if not graph_diff: - return None - - return GraphDiffSummary( - prev_collection_count=len(graph_diff.previous_collections), - curr_collection_count=len(graph_diff.current_collections), - added_collection_count=len(graph_diff.added_collections), - removed_collection_count=len(graph_diff.removed_collections), - added_edge_count=len(graph_diff.added_edges), - removed_edge_count=len(graph_diff.removed_edges), - already_processed_access_collection_count=len( - graph_diff.already_processed_access_collections - ), - already_processed_erasure_collection_count=len( - graph_diff.already_processed_erasure_collections - ), - skipped_added_edge_count=len(graph_diff.skipped_added_edges), - ) diff --git a/src/fides/api/graph/traversal.py b/src/fides/api/graph/traversal.py index fddd39a9b3..73ef2babd1 100644 --- a/src/fides/api/graph/traversal.py +++ b/src/fides/api/graph/traversal.py @@ -1,31 +1,40 @@ from __future__ import annotations +import json from typing import Any, Callable, Dict, List, Set, Tuple, cast import pydash.collections +from fideslang.validation import FidesKey from loguru import logger from fides.api.common_exceptions import TraversalError from fides.api.graph.config import ( ROOT_COLLECTION_ADDRESS, + TERMINATOR_ADDRESS, Collection, CollectionAddress, - Field, FieldAddress, FieldPath, GraphDataset, ) +from fides.api.graph.execution import ExecutionNode from fides.api.graph.graph import DatasetGraph, Edge, Node -from fides.api.util.collection_util import Row, append +from fides.api.models.privacy_request import RequestTask, TraversalDetails +from fides.api.util.collection_util import Row, append, partition from fides.api.util.logger_context_utils import Contextualizable, LoggerContextKeys from fides.api.util.matching_queue import MatchingQueue +ARTIFICIAL_NODES: List[CollectionAddress] = [ + ROOT_COLLECTION_ADDRESS, + TERMINATOR_ADDRESS, +] + Datastore = Dict[CollectionAddress, List[Row]] """A type expressing retrieved rows of data from a specified collection""" class TraversalNode(Contextualizable): - """Base traversal traversal_node type. This type will never be used directly.""" + """Traversal_node type. This type is used for building the graph, not for executing the graph.""" def __init__(self, node: Node): self.node = node @@ -71,18 +80,6 @@ def incoming_edges(self) -> Set[Edge]: for _, parent_field_path, self_field_path in tuples } - def incoming_edges_from_same_dataset(self) -> Set[Edge]: - """Return the incoming edges from the same dataset""" - return { - Edge( - p_collection_address.field_address(parent_field_path), - self.address.field_address(self_field_path), - ) - for p_collection_address, tuples in self.parents.items() - if p_collection_address.dataset == self.address.dataset - for _, parent_field_path, self_field_path in tuples - } - def outgoing_edges(self) -> Set[Edge]: """Return the outgoing edges to this traversal_node,in (self.address -> other.address) order.""" return { @@ -94,36 +91,21 @@ def outgoing_edges(self) -> Set[Edge]: for _, self_field_path, child_field_path in tuples } - @property - def query_field_paths(self) -> Set[FieldPath]: - """ - All of the possible field paths that we can query for possible filter values. - These are field paths that are the ends of incoming edges. 
- """ - return {edge.f2.field_path for edge in self.incoming_edges()} + def incoming_edges_by_collection(self) -> Dict[CollectionAddress, List[Edge]]: + return partition(self.incoming_edges(), lambda e: e.f1.collection_address()) - def typed_filtered_values(self, input_data: Dict[str, List[Any]]) -> Dict[str, Any]: + def input_keys(self) -> List[CollectionAddress]: + """Returns the inputs to the current node that are data dependencies + This is copied and saved to the RequestTask and used to maintain a consistent order + for passing in data for an access task """ - Return a filtered list of key/value sets of data items that are both in - the list of incoming edge fields, and contain data in the input data set. - - The values are cast based on field types, if those types are specified. - """ - out = {} - for key, values in input_data.items(): - path: FieldPath = FieldPath.parse(key) - field: Field | None = self.node.collection.field(path) - - if field and path in self.query_field_paths and isinstance(values, list): - cast_values = [field.cast(v) for v in values] - filtered = list(filter(lambda x: x is not None, cast_values)) - if filtered: - out[key] = filtered - return out + return sorted(self.incoming_edges_by_collection().keys()) def can_run_given(self, remaining_node_keys: Set[CollectionAddress]) -> bool: """True if finished_node_keys covers all the nodes that this traversal_node is waiting for. If all nodes this traversal_node is waiting for have finished, it's ok for this traversal_node to run. + + NOTE: "After" functionality may not work as expected. """ if self.node.collection.after.intersection( remaining_node_keys @@ -165,6 +147,49 @@ def debug(self) -> Dict[str, Any]: def get_log_context(self) -> Dict[LoggerContextKeys, Any]: return {LoggerContextKeys.collection: self.node.collection.name} + def format_traversal_details_for_save(self) -> Dict: + """Convert key traversal details from the TraversalNode for save on the RequestTask. + + The RequestTask will be retrieved from the database and the traversal details + used to build the ExecutionNode for DSR 3.0. + """ + + connection_key: FidesKey = self.node.dataset.connection_key + + return TraversalDetails( + dataset_connection_key=connection_key, + incoming_edges=[ + [edge.f1.value, edge.f2.value] for edge in self.incoming_edges() + ], + outgoing_edges=[ + [edge.f1.value, edge.f2.value] for edge in self.outgoing_edges() + ], + input_keys=[tn.value for tn in self.input_keys()], + ).dict() + + def to_mock_request_task(self) -> RequestTask: + """Converts a portion of the TraversalNode into a RequestTask - used in building + dry run queries or for supporting Deprecated DSR 2.0. Request Tasks were introduced in DSR 3.0 + """ + collection_data = json.loads(self.node.collection.json()) + return RequestTask( # Mock a RequestTask object in memory + collection_address=self.node.address.value, + dataset_name=self.node.address.dataset, + collection_name=self.node.address.collection, + collection=collection_data, + traversal_details=self.format_traversal_details_for_save(), + ) + + def to_mock_execution_node(self) -> ExecutionNode: + """Converts a TraversalNode into an ExecutionNode - used for supporting DSR 2.0, to convert + Traversal Nodes into the Execution Node format which is needed for executing the graph in + DSR 3.0 + + DSR 3.0 on the other hand, creates ExecutionNodes from data on the RequestTask. 
+ """ + request_task: RequestTask = self.to_mock_request_task() + return ExecutionNode(request_task) + def artificial_traversal_node(address: CollectionAddress) -> TraversalNode: """generate an 'artificial' traversal_node pointing to the given address. This is used to @@ -261,8 +286,6 @@ def traverse( # pylint: disable=R0914 Returns a list of termination traversal_node addresses so that we can take action on completed traversal. - - We define the root traversal_node as a traversal_node whose children are any nodes that have identity (seed) data. We start with diff --git a/src/fides/api/main.py b/src/fides/api/main.py index 022949bf03..60d7e7e2ff 100644 --- a/src/fides/api/main.py +++ b/src/fides/api/main.py @@ -32,6 +32,10 @@ from fides.api.service.privacy_request.email_batch_service import ( initiate_scheduled_batch_email_send, ) +from fides.api.service.privacy_request.request_service import ( + initiate_poll_for_exited_privacy_request_tasks, + initiate_scheduled_dsr_data_removal, +) from fides.api.tasks.scheduled.scheduler import async_scheduler, scheduler from fides.api.ui import ( get_admin_index_as_response, @@ -278,6 +282,8 @@ async def setup_server() -> None: async_scheduler.start() initiate_scheduled_batch_email_send() + initiate_poll_for_exited_privacy_request_tasks() + initiate_scheduled_dsr_data_removal() logger.debug("Sending startup analytics events...") # Avoid circular imports diff --git a/src/fides/api/models/connectionconfig.py b/src/fides/api/models/connectionconfig.py index fd86e73555..dbc6c65411 100644 --- a/src/fides/api/models/connectionconfig.py +++ b/src/fides/api/models/connectionconfig.py @@ -44,11 +44,11 @@ class ConnectionType(enum.Enum): mssql = "mssql" mariadb = "mariadb" bigquery = "bigquery" - manual = "manual" # Run as part of the traversal + manual = "manual" # Deprecated - use manual_webhook instead sovrn = "sovrn" attentive = "attentive" dynamodb = "dynamodb" - manual_webhook = "manual_webhook" # Run before the traversal + manual_webhook = "manual_webhook" # Runs upfront before the traversal timescale = "timescale" fides = "fides" generic_erasure_email = "generic_erasure_email" # Run after the traversal diff --git a/src/fides/api/models/policy.py b/src/fides/api/models/policy.py index 813e460206..24b0d078f9 100644 --- a/src/fides/api/models/policy.py +++ b/src/fides/api/models/policy.py @@ -33,8 +33,11 @@ class CurrentStep(EnumType): pre_webhooks = "pre_webhooks" access = "access" + upload_access = "upload_access" erasure = "erasure" + finalize_erasure = "finalize_erasure" consent = "consent" + finalize_consent = "finalize_consent" email_post_send = "email_post_send" post_webhooks = "post_webhooks" diff --git a/src/fides/api/models/privacy_request.py b/src/fides/api/models/privacy_request.py index 2959cc3585..41280d52cd 100644 --- a/src/fides/api/models/privacy_request.py +++ b/src/fides/api/models/privacy_request.py @@ -5,7 +5,7 @@ import json from datetime import datetime, timedelta from enum import Enum as EnumType -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from celery.result import AsyncResult from loguru import logger @@ -22,7 +22,8 @@ from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.ext.mutable import MutableDict, MutableList -from sqlalchemy.orm import Session, backref, relationship +from sqlalchemy.orm import Query, RelationshipProperty, Session, backref, relationship +from sqlalchemy.orm.dynamic 
import AppenderQuery
 from sqlalchemy_utils.types.encrypted.encrypted_type import (
     AesGcmEngine,
     StringEncryptedType,
 )
@@ -38,8 +39,11 @@
 from fides.api.db.base_class import Base  # type: ignore[attr-defined]
 from fides.api.db.base_class import JSONTypeOverride
 from fides.api.db.util import EnumColumn
-from fides.api.graph.config import CollectionAddress
-from fides.api.graph.graph_differences import GraphRepr
+from fides.api.graph.config import (
+    ROOT_COLLECTION_ADDRESS,
+    TERMINATOR_ADDRESS,
+    CollectionAddress,
+)
 from fides.api.models.audit_log import AuditLog
 from fides.api.models.client import ClientDetail
 from fides.api.models.fides_user import FidesUser
@@ -71,6 +75,8 @@ from fides.api.util.cache import (
     CustomJSONEncoder,
     FidesopsRedis,
+    _custom_decoder,
+    celery_tasks_in_flight,
     get_all_cache_keys_for_privacy_request,
     get_async_task_tracking_cache_key,
     get_cache,
@@ -80,7 +86,7 @@ get_identity_cache_key,
     get_masking_secret_cache_key,
 )
-from fides.api.util.collection_util import Row
+from fides.api.util.collection_util import Row, extract_key_for_address
 from fides.api.util.constants import API_DATE_FORMAT
 from fides.api.util.identity_verification import IdentityVerificationMixin
 from fides.api.util.logger_context_utils import Contextualizable, LoggerContextKeys
@@ -91,8 +97,11 @@ EXECUTION_CHECKPOINTS = [
     CurrentStep.pre_webhooks,
     CurrentStep.access,
+    CurrentStep.upload_access,
     CurrentStep.erasure,
+    CurrentStep.finalize_erasure,
     CurrentStep.consent,
+    CurrentStep.finalize_consent,
     CurrentStep.email_post_send,
     CurrentStep.post_webhooks,
 ]
@@ -291,11 +300,38 @@ class PrivacyRequest(
     due_date = Column(DateTime(timezone=True), nullable=True)
     awaiting_email_send_at = Column(DateTime(timezone=True), nullable=True)
 
+    # Encrypted filtered access results saved for later retrieval
+    filtered_final_upload = Column(  # An encrypted JSON String - Dict[Dict[str, List[Row]]] - rule keys mapped to the filtered access results
+        StringEncryptedType(
+            type_in=String(),
+            key=CONFIG.security.app_encryption_key,
+            engine=AesGcmEngine,
+            padding="pkcs5",
+        ),
+    )
+
+    # Encrypted access result URLs saved for later retrieval
+    access_result_urls = Column(  # An encrypted JSON String - the URLs where the uploaded access results can be retrieved
+        StringEncryptedType(
+            type_in=JSONTypeOverride,
+            key=CONFIG.security.app_encryption_key,
+            engine=AesGcmEngine,
+            padding="pkcs5",
+        ),
+    )
+
     # Non-DB fields that are optionally added throughout the codebase
     action_required_details: Optional[CheckpointActionRequired] = None
     execution_and_audit_logs_by_dataset: Optional[property] = None
     resume_endpoint: Optional[str] = None
 
+    request_tasks: RelationshipProperty[AppenderQuery] = relationship(
+        "RequestTask",
+        back_populates="privacy_request",
+        lazy="dynamic",
+        order_by="RequestTask.created_at",
+    )
+
     @property
     def days_left(self: PrivacyRequest) -> Union[int, None]:
         if self.due_date is None:
@@ -494,14 +530,6 @@ def verify_identity(self, db: Session, provided_code: str) -> "PrivacyRequest":
         self.save(db)
         return self
 
-    def cache_task_id(self, task_id: str) -> None:
-        """Sets a task_id for this privacy request's asynchronous execution."""
-        cache: FidesopsRedis = get_cache()
-        cache.set(
-            get_async_task_tracking_cache_key(self.id),
-            task_id,
-        )
-
     def get_cached_task_id(self) -> Optional[str]:
         """Gets the cached task ID for this privacy request."""
         cache: FidesopsRedis = get_cache()
@@ -582,7 +610,9 @@ def get_cached_custom_privacy_request_fields(self) -> Dict[str, Any]:
         return
result def get_results(self) -> Dict[str, Any]: - """Retrieves all cached identity data associated with this Privacy Request""" + """Retrieves all cached identity data associated with this Privacy Request + Just used in testing + """ cache: FidesopsRedis = get_cache() result_prefix = f"{self.id}__*" return cache.get_encoded_objects_by_prefix(result_prefix) @@ -649,28 +679,24 @@ def get_paused_collection_details( def cache_failed_checkpoint_details( self, step: Optional[CurrentStep] = None, - collection: Optional[CollectionAddress] = None, ) -> None: """ - Cache a checkpoint where the privacy request failed so we can later resume from this failure point. + Cache the checkpoint reached in the Privacy Request so it can be resumed from this point in + case of failure. - Cache details about the failed step and failed collection details (where applicable). - No specific input data is required to resume a failed request, so action_needed is None. """ cache_action_required( cache_key=f"FAILED_LOCATION__{self.id}", step=step, - collection=collection, + collection=None, # Deprecated for failed checkpoint details action_needed=None, ) def get_failed_checkpoint_details( self, ) -> Optional[CheckpointActionRequired]: - """Get details about the failed step (access or erasure) and collection that triggered failure. - - If DSR processing failed within the graph, this will let us know if we should resume privacy request execution - from the "access" or "erasure" portion of the privacy request flow. + """Get the latest checkpoint reached in Privacy Request processing so we know where to resume + in case of failure. """ return get_action_required_details(cached_key=f"EN_FAILED_LOCATION__{self.id}") @@ -796,71 +822,6 @@ def get_manual_webhook_erasure_input_non_strict( ).dict() return manual_webhook.empty_fields_dict - def cache_manual_access_input( - self, collection: CollectionAddress, manual_rows: Optional[List[Row]] - ) -> None: - """Cache manually added rows for the given CollectionAddress. This is for use by the *manual* connector which is integrated with the graph.""" - cache: FidesopsRedis = get_cache() - cache.set_encoded_object( - f"MANUAL_INPUT__{self.id}__{collection.value}", - manual_rows, - ) - - def get_manual_access_input( - self, collection: CollectionAddress - ) -> Optional[List[Row]]: - """Retrieve manually added rows from the cache for the given CollectionAddress. - Returns the manual data if it exists, otherwise None. - - This is for use by the *manual* connector which is integrated with the graph. - """ - cache: FidesopsRedis = get_cache() - cached_results: Optional[ - Dict[str, Optional[List[Row]]] - ] = cache.get_encoded_objects_by_prefix( - f"MANUAL_INPUT__{self.id}__{collection.value}" - ) - return list(cached_results.values())[0] if cached_results else None - - def cache_manual_erasure_count( - self, collection: CollectionAddress, count: int - ) -> None: - """Cache the number of rows manually masked for a given collection. - - This is for use by the *manual* connector which is integrated with the graph. - """ - cache: FidesopsRedis = get_cache() - cache.set_encoded_object( - f"MANUAL_MASK__{self.id}__{collection.value}", - count, - ) - - def get_manual_erasure_count(self, collection: CollectionAddress) -> Optional[int]: - """Retrieve number of rows manually masked for this collection from the cache. - - Cached as an integer to mimic what we return from erasures in an automated way. - This is for use by the *manual* connector which is integrated with the graph. 
- """ - cache: FidesopsRedis = get_cache() - prefix = f"MANUAL_MASK__{self.id}__{collection.value}" - value_dict: Optional[Dict[str, int]] = cache.get_encoded_objects_by_prefix( # type: ignore - prefix - ) - return list(value_dict.values())[0] if value_dict else None - - def cache_access_graph(self, value: GraphRepr) -> None: - """Cache a representation of the graph built for the access request""" - cache: FidesopsRedis = get_cache() - cache.set_encoded_object(f"ACCESS_GRAPH__{self.id}", value) - - def get_cached_access_graph(self) -> Optional[GraphRepr]: - """Fetch the graph built for the access request""" - cache: FidesopsRedis = get_cache() - value_dict: Optional[ - Dict[str, Optional[GraphRepr]] - ] = cache.get_encoded_objects_by_prefix(f"ACCESS_GRAPH__{self.id}") - return list(value_dict.values())[0] if value_dict else None - def cache_data_use_map(self, value: Dict[str, Set[str]]) -> None: """ Cache a dict of collections traversed in the privacy request @@ -971,19 +932,46 @@ def pause_processing_for_email_send(self, db: Session) -> None: self.status = PrivacyRequestStatus.awaiting_email_send self.save(db=db) + def get_request_task_celery_task_ids(self) -> List[str]: + """Returns the celery task ids for each of the Request Tasks (subtasks) + + It is possible Request Tasks get queued multiple times, so the celery task + id returned is the last celery task queued. + """ + request_task_celery_ids: List[str] = [] + for request_task in self.request_tasks: + request_task_id: Optional[str] = request_task.get_cached_task_id() + if request_task_id: + request_task_celery_ids.append(request_task_id) + return request_task_celery_ids + def cancel_processing(self, db: Session, cancel_reason: Optional[str]) -> None: - """Cancels a privacy request. Currently should only cancel 'pending' tasks""" + """Cancels a privacy request. Currently should only cancel 'pending' tasks + + Just in case, also tries to cancel sub tasks (Request Tasks) if applicable, + although these shouldn't exist if the Privacy Request is pending. + """ if self.canceled_at is None: self.status = PrivacyRequestStatus.canceled self.cancel_reason = cancel_reason self.canceled_at = datetime.utcnow() self.save(db) - task_id = self.get_cached_task_id() - if task_id: - logger.info("Revoking task {} for request {}", task_id, self.id) - # Only revokes if execution is not already in progress - celery_app.control.revoke(task_id, terminate=False) + task_ids: List[ + str + ] = ( + self.get_request_task_celery_task_ids() + ) # Celery tasks for sub tasks (DSR 3.0 Request Tasks) + parent_task_id = ( + self.get_cached_task_id() + ) # Celery task for current Privacy Request + if parent_task_id: + task_ids.append(parent_task_id) + + for celery_task_id in task_ids: + logger.info("Revoking task {} for request {}", celery_task_id, self.id) + # Only revokes if execution is not already in progress. 
+ celery_app.control.revoke(celery_task_id, terminate=False) def error_processing(self, db: Session) -> None: """Mark privacy request as errored, and note time processing was finished""" @@ -1002,6 +990,173 @@ def error_processing(self, db: Session) -> None: def get_log_context(self) -> Dict[LoggerContextKeys, Any]: return {LoggerContextKeys.privacy_request_id: self.id} + @property + def access_tasks(self) -> Query: + """Return existing Access Request Tasks for the current privacy request""" + return self.request_tasks.filter(RequestTask.action_type == ActionType.access) + + @property + def erasure_tasks(self) -> Query: + """Return existing Erasure Request Tasks for the current privacy request""" + return self.request_tasks.filter(RequestTask.action_type == ActionType.erasure) + + @property + def consent_tasks(self) -> Query: + """Return existing Consent Request Tasks for the current privacy request""" + return self.request_tasks.filter(RequestTask.action_type == ActionType.consent) + + def get_existing_request_task( + self, + db: Session, + action_type: ActionType, + collection_address: CollectionAddress, + ) -> Optional[RequestTask]: + """Returns a Request Task for the current Privacy Request with action type and collection address""" + return ( + db.query(RequestTask) + .filter( + RequestTask.privacy_request_id == self.id, + RequestTask.action_type == action_type, + RequestTask.collection_address == collection_address.value, + ) + .first() + ) + + def get_tasks_by_action(self, action: ActionType) -> Query: + """Convenience helper to get RequestTasks of a certain action type for the given + privacy request""" + if action == ActionType.access: + return self.access_tasks + + if action == ActionType.erasure: + return self.erasure_tasks + + if action == ActionType.consent: + return self.consent_tasks + + raise Exception(f"Unsupported Privacy Request Action Type {action}") + + def get_root_task_by_action(self, action: ActionType) -> RequestTask: + """Get the root task for a specific action""" + root: Optional[RequestTask] = ( + self.get_tasks_by_action(action) + .filter(RequestTask.collection_address == ROOT_COLLECTION_ADDRESS.value) + .first() + ) + if not root: + raise Exception( + f"Expected {action.value.capitalize()} root node could not be found on privacy request {self.id}" + ) + assert root # for mypy + return root + + def get_terminate_task_by_action(self, action: ActionType) -> RequestTask: + """Get the terminate task for a specific action""" + terminate: Optional[RequestTask] = ( + self.get_tasks_by_action(action) + .filter(RequestTask.collection_address == TERMINATOR_ADDRESS.value) + .first() + ) + if not terminate: + raise Exception( + f"Expected {action.value.capitalize()} terminate node could not be found on privacy request {self.id}" + ) + assert terminate # for mypy + return terminate + + def get_raw_access_results(self) -> Dict[str, Optional[List[Row]]]: + """Retrieve the *raw* access data saved on the individual access nodes + + These shouldn't be returned to the user - they are not filtered by data category + """ + # For DSR 3.0, pull these off of the RequestTask.access_data fields + if self.access_tasks.count(): + final_results: Dict = {} + for task in self.access_tasks.filter( + RequestTask.status == ExecutionLogStatus.complete, + RequestTask.collection_address.notin_( + [ROOT_COLLECTION_ADDRESS.value, TERMINATOR_ADDRESS.value] + ), + ): + final_results[task.collection_address] = task.get_decoded_access_data() + + return final_results + + # TODO Remove when we stop support
for DSR 2.0 + # We will no longer be pulling access results from the cache, but off of Request Tasks instead + cache: FidesopsRedis = get_cache() + value_dict = cache.get_encoded_objects_by_prefix(f"{self.id}__access_request") + # extract request id to return a map of address:value + number_of_leading_strings_to_exclude = 2 + return { + extract_key_for_address(k, number_of_leading_strings_to_exclude): v + for k, v in value_dict.items() + } + + def get_raw_masking_counts(self) -> Dict[str, int]: + """For parity, return the rows masked for an erasure request + + This is largely just used for testing + """ + if self.erasure_tasks.count(): + # For DSR 3.0 + return { + t.collection_address: t.rows_masked + for t in self.erasure_tasks.filter( + RequestTask.status.in_(COMPLETED_EXECUTION_LOG_STATUSES) + ) + if not t.is_root_task and not t.is_terminator_task + } + + # TODO Remove when we stop support for DSR 2.0 + cache: FidesopsRedis = get_cache() + value_dict = cache.get_encoded_objects_by_prefix(f"{self.id}__erasure_request") + # extract request id to return a map of address:value + number_of_leading_strings_to_exclude = 2 + return {extract_key_for_address(k, number_of_leading_strings_to_exclude): v for k, v in value_dict.items()} # type: ignore + + def get_consent_results(self) -> Dict[str, bool]: + """For parity, return whether a consent request was sent for third + party consent propagation + + This is largely just used for testing + """ + if self.consent_tasks.count(): + # For DSR 3.0 + return { + t.collection_address: t.consent_sent + for t in self.consent_tasks.filter( + RequestTask.status.in_(EXITED_EXECUTION_LOG_STATUSES) + ) + if not t.is_root_task and not t.is_terminator_task + } + # DSR 2.0 does not cache the results so nothing to do here + return {} + + def save_filtered_access_results( + self, db: Session, results: Dict[str, Dict[str, List[Row]]] + ) -> None: + """ + For access requests, save the access data filtered by data category that we uploaded to the end user + + This is keyed by policy rule key, because we uploaded different packages for different policy rules + + """ + if not self.policy.get_rules_for_action(action_type=ActionType.access): + return None + + self.filtered_final_upload = json.dumps(results, cls=CustomJSONEncoder) + self.save(db) + + return None + + def get_filtered_access_results(self) -> Dict[str, Dict[str, List[Row]]]: + """Fetch the same filtered access results we uploaded to the user""" + return json.loads( + self.filtered_final_upload or "{}", + object_hook=_custom_decoder, + ) + class PrivacyRequestError(Base): """The DB ORM model to track PrivacyRequests error message status.""" @@ -1395,6 +1550,17 @@ class ExecutionLogStatus(EnumType): skipped = "skipped" +COMPLETED_EXECUTION_LOG_STATUSES = [ + ExecutionLogStatus.complete, + ExecutionLogStatus.skipped, +] +EXITED_EXECUTION_LOG_STATUSES = [ + ExecutionLogStatus.complete, + ExecutionLogStatus.error, + ExecutionLogStatus.skipped, +] + + class ExecutionLog(Base): """ Stores the individual execution logs associated with a PrivacyRequest.
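The two status groups above are read in slightly different ways by the new scheduling code: "completed" statuses are the ones that unblock downstream tasks, while "exited" statuses only mean a task has stopped making progress on its own. A minimal sketch of that distinction follows; the helper function names are illustrative and not part of this change:

def task_unblocks_downstream(status: ExecutionLogStatus) -> bool:
    # A downstream Request Task may be queued once every one of its upstream
    # tasks has completed or been skipped (see upstream_tasks_complete below).
    return status in COMPLETED_EXECUTION_LOG_STATUSES

def task_has_exited(status: ExecutionLogStatus) -> bool:
    # "Exited" additionally includes errored tasks - tasks that have stopped,
    # successfully or not. This is the group the scheduled
    # initiate_poll_for_exited_privacy_request_tasks job is named for.
    return status in EXITED_EXECUTION_LOG_STATUSES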
@@ -1465,3 +1631,223 @@ def _parse_cache_to_checkpoint_action_required( collection=collection, action_needed=action_needed, ) + + +class TraversalDetails(FidesSchema): + """Schema to format saving pre-calculated traversal details on RequestTask.traversal_details""" + + dataset_connection_key: str + incoming_edges: List[Tuple[str, str]] + outgoing_edges: List[Tuple[str, str]] + input_keys: List[str] + + +class RequestTask(Base): + """ + An individual Task for a Privacy Request. + + When we execute a PrivacyRequest, we build a graph by combining the current datasets with the identity data + and we save the nodes (collections) in the graph as Request Tasks. + + Currently, we build access, erasure, and consent Request Tasks. + """ + + privacy_request_id = Column( + String, + ForeignKey(PrivacyRequest.id_field_path, ondelete="SET NULL"), + nullable=True, + index=True, + ) + + # Identifiers of this request task + collection_address = Column( + String, nullable=False, index=True + ) # Of the format dataset_name:collection_name for convenience + dataset_name = Column(String, nullable=False, index=True) + collection_name = Column(String, nullable=False, index=True) + action_type = Column(EnumColumn(ActionType), nullable=False, index=True) + + status = Column( + EnumColumn(ExecutionLogStatus), # character varying in database + index=True, + nullable=False, + ) + + upstream_tasks = Column( + MutableList.as_mutable(JSONB) + ) # List of collection address strings + downstream_tasks = Column( + MutableList.as_mutable(JSONB) + ) # List of collection address strings + all_descendant_tasks = Column( + MutableList.as_mutable(JSONB) + ) # All tasks that can be reached by the current task. This is useful when this task fails, + # and we can mark every single one of these as failed. + + # Raw data retrieved from an access request is stored here. This contains all of the + # intermediate data we retrieved, needed for downstream tasks, but hasn't been filtered + # by data category for the end user. + access_data = Column( # An encrypted JSON String - saved as a list of Rows + StringEncryptedType( + type_in=String(), + key=CONFIG.security.app_encryption_key, + engine=AesGcmEngine, + padding="pkcs5", + ), + ) + + # This is the raw access data saved in erasure format (with placeholders preserved) to perform a masking request. + # First saved on the access node, and then copied to the corresponding erasure node. 
+ data_for_erasures = Column( # An encrypted JSON String - saved as a list of rows + StringEncryptedType( + type_in=String(), + key=CONFIG.security.app_encryption_key, + engine=AesGcmEngine, + padding="pkcs5", + ), + ) + + # Written after an erasure is completed + rows_masked = Column(Integer) + # Written after a consent request is completed - not all consent + # connectors will end up sending a request + consent_sent = Column(Boolean) + + # Stores a serialized collection that can be transformed back into a Collection to help + # execute the current task + collection = Column(MutableDict.as_mutable(JSONB)) + # Stores key details from traversal.traverse in the format of TraversalDetails + traversal_details = Column(MutableDict.as_mutable(JSONB)) + + privacy_request: RelationshipProperty[PrivacyRequest] = relationship( + "PrivacyRequest", + back_populates="request_tasks", + uselist=False, + ) + + @property + def request_task_address(self) -> CollectionAddress: + """Convert the collection_address string into a CollectionAddress""" + return CollectionAddress.from_string(self.collection_address) + + @property + def is_root_task(self) -> bool: + """Convenience helper for asserting whether the task is a root task""" + return self.request_task_address == ROOT_COLLECTION_ADDRESS + + @property + def is_terminator_task(self) -> bool: + """Convenience helper for asserting whether the task is a terminator task""" + return self.request_task_address == TERMINATOR_ADDRESS + + def get_cached_task_id(self) -> Optional[str]: + """Gets the cached celery task ID for this request task.""" + cache: FidesopsRedis = get_cache() + task_id = cache.get(get_async_task_tracking_cache_key(self.id)) + return task_id + + def get_decoded_access_data(self) -> List[Row]: + """Decode the collected access data""" + return json.loads(self.access_data or "[]", object_hook=_custom_decoder) + + def get_decoded_data_for_erasures(self) -> List[Row]: + """Decode the erasure data needed to build masking requests""" + return json.loads(self.data_for_erasures or "[]", object_hook=_custom_decoder) + + def update_status(self, db: Session, status: ExecutionLogStatus) -> None: + """Helper method to update a task's status""" + self.status = status + self.save(db) + + def get_tasks_with_same_action_type( + self, db: Session, collection_address_str: str + ) -> Query: + """Fetch tasks on the same privacy request with the same action type as the current task, matching the given collection address""" + return db.query(RequestTask).filter( + RequestTask.privacy_request_id == self.privacy_request_id, + RequestTask.action_type == self.action_type, + RequestTask.collection_address == collection_address_str, + ) + + def get_pending_downstream_tasks(self, db: Session) -> Query: + """Returns the immediate downstream task objects that are still pending""" + return db.query(RequestTask).filter( + RequestTask.privacy_request_id == self.privacy_request_id, + RequestTask.action_type == self.action_type, + RequestTask.collection_address.in_(self.downstream_tasks or []), + RequestTask.status == ExecutionLogStatus.pending, + ) + + def can_queue_request_task(self, db: Session, should_log: bool = False) -> bool: + """Returns True if upstream tasks are complete and the current Request Task + is not running in another celery task. + + This check ignores the task's own database status - that is checked elsewhere.
+ """ + return self.upstream_tasks_complete( + db, should_log + ) and not self.request_task_running(should_log) + + def upstream_tasks_complete(self, db: Session, should_log: bool = False) -> bool: + """Determines if all of the upstream tasks of the current task are complete""" + upstream_tasks: Query = self.upstream_tasks_objects(db) + tasks_complete: bool = all( + upstream_task.status in COMPLETED_EXECUTION_LOG_STATUSES + for upstream_task in upstream_tasks + ) and upstream_tasks.count() == len(self.upstream_tasks or []) + + if not tasks_complete and should_log: + logger.debug( + "Upstream tasks incomplete for {} task {}. Privacy Request: {}, Request Task {}.", + self.action_type.value, + self.collection_address, + self.privacy_request_id, + self.id, + ) + + return tasks_complete + + def upstream_tasks_objects(self, db: Session) -> Query: + """Returns Request Task objects that are upstream of the current Request Task""" + upstream_tasks: Query = db.query(RequestTask).filter( + RequestTask.privacy_request_id == self.privacy_request_id, + RequestTask.collection_address.in_(self.upstream_tasks or []), + RequestTask.action_type == self.action_type, + ) + return upstream_tasks + + def request_task_running(self, should_log: bool = False) -> bool: + """Returns a rough measure if the Request Task is already running - + not 100% accurate. + + This is further only applicable if you are running workers and + CONFIG.execution.task_always_eager=False. This is just an extra check to reduce possible + over-scheduling, but it is also okay if the same node runs multiple times. + """ + celery_task_id: Optional[str] = self.get_cached_task_id() + if not celery_task_id: + return False + + if should_log: + logger.debug( + "Celery Task ID {} found for {} task {}. Privacy Request: {}, Request Task {}.", + celery_task_id, + self.action_type.value, + self.collection_address, + self.privacy_request_id, + self.id, + ) + + task_in_flight: bool = celery_tasks_in_flight([celery_task_id]) + + if task_in_flight and should_log: + logger.debug( + "Celery Task {} already processing for {} task {}. 
Privacy Request: {}, Request Task {}.", + celery_task_id, + self.action_type.value, + self.collection_address, + self.privacy_request_id, + self.id, + ) + + return task_in_flight diff --git a/src/fides/api/schemas/privacy_request.py b/src/fides/api/schemas/privacy_request.py index c1ba35dfd9..3e732f70fe 100644 --- a/src/fides/api/schemas/privacy_request.py +++ b/src/fides/api/schemas/privacy_request.py @@ -136,6 +136,19 @@ class Config: use_enum_values = True +class PrivacyRequestTaskSchema(FidesSchema): + """Schema for Privacy Request Tasks, which are individual nodes that are queued""" + + id: str + collection_address: str + status: ExecutionLogStatus + created_at: datetime + updated_at: datetime + upstream_tasks: List[str] + downstream_tasks: List[str] + action_type: ActionType + + class ExecutionLogDetailResponse(ExecutionLogResponse): """Schema for the detailed ExecutionLogs when accessed directly""" diff --git a/src/fides/api/service/connectors/__init__.py b/src/fides/api/service/connectors/__init__.py index c0f6f80f1d..4b32205576 100644 --- a/src/fides/api/service/connectors/__init__.py +++ b/src/fides/api/service/connectors/__init__.py @@ -24,9 +24,6 @@ FidesConnector as FidesConnector, ) from fides.api.service.connectors.http_connector import HTTPSConnector as HTTPSConnector -from fides.api.service.connectors.manual_connector import ( - ManualConnector as ManualConnector, -) from fides.api.service.connectors.manual_webhook_connector import ( ManualWebhookConnector as ManualWebhookConnector, ) @@ -65,7 +62,6 @@ ConnectionType.generic_consent_email.value: GenericConsentEmailConnector, ConnectionType.generic_erasure_email.value: GenericErasureEmailConnector, ConnectionType.https.value: HTTPSConnector, - ConnectionType.manual.value: ManualConnector, ConnectionType.manual_webhook.value: ManualWebhookConnector, ConnectionType.mariadb.value: MariaDBConnector, ConnectionType.mongodb.value: MongoDBConnector, diff --git a/src/fides/api/service/connectors/base_connector.py b/src/fides/api/service/connectors/base_connector.py index 82477698e4..9990a61771 100644 --- a/src/fides/api/service/connectors/base_connector.py +++ b/src/fides/api/service/connectors/base_connector.py @@ -4,10 +4,10 @@ from sqlalchemy.orm import Session from fides.api.common_exceptions import NotSupportedForCollection -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.connectionconfig import ConnectionConfig, ConnectionTestStatus from fides.api.models.policy import Policy -from fides.api.models.privacy_request import PrivacyRequest +from fides.api.models.privacy_request import PrivacyRequest, RequestTask from fides.api.service.connectors.query_config import QueryConfig from fides.api.util.collection_util import Row from fides.config import CONFIG @@ -39,7 +39,7 @@ def __init__(self, configuration: ConnectionConfig): self.db_client: Optional[DB_CONNECTOR_TYPE] = None @abstractmethod - def query_config(self, node: TraversalNode) -> QueryConfig[Any]: + def query_config(self, node: ExecutionNode) -> QueryConfig[Any]: """Return the query config that corresponds to this connector type""" @abstractmethod @@ -63,9 +63,10 @@ def client(self) -> DB_CONNECTOR_TYPE: @abstractmethod def retrieve_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, input_data: Dict[str, List[Any]], ) -> List[Row]: """Retrieve data in a connector dependent way based on input data. 
@@ -76,11 +77,11 @@ def retrieve_data( @abstractmethod def mask_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, rows: List[Row], - input_data: Dict[str, List[Any]], ) -> int: """Execute a masking request. Return the number of rows that have been updated @@ -91,9 +92,10 @@ def mask_data( def run_consent_request( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, identity_data: Dict[str, Any], session: Session, ) -> bool: @@ -106,7 +108,7 @@ def run_consent_request( f"Consent requests are not supported for connectors of type {self.configuration.connection_type}" ) - def dry_run_query(self, node: TraversalNode) -> Optional[str]: + def dry_run_query(self, node: ExecutionNode) -> Optional[str]: """Generate a dry-run query to display action that will be taken""" return self.query_config(node).dry_run_query() diff --git a/src/fides/api/service/connectors/dynamodb_connector.py b/src/fides/api/service/connectors/dynamodb_connector.py index f917188645..995691f93c 100644 --- a/src/fides/api/service/connectors/dynamodb_connector.py +++ b/src/fides/api/service/connectors/dynamodb_connector.py @@ -7,10 +7,10 @@ import fides.connectors.aws as aws_connector from fides.api.common_exceptions import ConnectionException -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.connectionconfig import ConnectionTestStatus from fides.api.models.policy import Policy -from fides.api.models.privacy_request import PrivacyRequest +from fides.api.models.privacy_request import PrivacyRequest, RequestTask from fides.api.schemas.connection_configuration.connection_secrets_dynamodb import ( DynamoDBSchema, ) @@ -49,7 +49,7 @@ def create_client(self) -> Any: # type: ignore except ValueError: raise ConnectionException("Value Error connecting to AWS DynamoDB.") - def query_config(self, node: TraversalNode) -> QueryConfig[Any]: + def query_config(self, node: ExecutionNode) -> QueryConfig[Any]: """Query wrapper corresponding to the input traversal_node.""" client = self.client() try: @@ -87,9 +87,10 @@ def test_connection(self) -> Optional[ConnectionTestStatus]: def retrieve_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, input_data: Dict[str, List[Any]], ) -> List[Row]: """ @@ -132,11 +133,11 @@ def retrieve_data( def mask_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, rows: List[Row], - input_data: Dict[str, List[Any]], ) -> int: """Execute a masking requestfor DynamoDB""" diff --git a/src/fides/api/service/connectors/fides_connector.py b/src/fides/api/service/connectors/fides_connector.py index e709bb5b7f..e4027a4344 100644 --- a/src/fides/api/service/connectors/fides_connector.py +++ b/src/fides/api/service/connectors/fides_connector.py @@ -2,14 +2,14 @@ from loguru import logger as log -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.connectionconfig import ( ConnectionConfig, ConnectionTestStatus, ConnectionType, ) from fides.api.models.policy import Policy -from fides.api.models.privacy_request import PrivacyRequest +from fides.api.models.privacy_request import PrivacyRequest, RequestTask from 
fides.api.schemas.connection_configuration.connection_secrets_fides import ( FidesConnectorSchema, ) @@ -26,7 +26,10 @@ class FidesConnector(BaseConnector[FidesClient]): - """A connector that forwards requests to other Fides instances""" + """A connector that forwards requests to other Fides instances + + This has not been updated to work with DSR 3.0 and is assumed to break. + """ def __init__(self, configuration: ConnectionConfig): super().__init__(configuration) @@ -42,7 +45,7 @@ def __init__(self, configuration: ConnectionConfig): else DEFAULT_POLLING_INTERVAL ) - def query_config(self, node: TraversalNode) -> QueryConfig[Any]: + def query_config(self, node: ExecutionNode) -> QueryConfig[Any]: """Return the query config that corresponds to this connector type""" # no query config for fides connectors @@ -82,9 +85,10 @@ def test_connection(self) -> Optional[ConnectionTestStatus]: def retrieve_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, input_data: Dict[str, List[Any]], ) -> List[Row]: """Execute access request and fetch access data from remote Fides""" @@ -132,11 +136,11 @@ def retrieve_data( def mask_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, rows: List[Row], - input_data: Dict[str, List[Any]], ) -> int: """Execute an erasure request on remote fides""" identity_data = { diff --git a/src/fides/api/service/connectors/http_connector.py b/src/fides/api/service/connectors/http_connector.py index a2e4ac3200..5f70fe24a4 100644 --- a/src/fides/api/service/connectors/http_connector.py +++ b/src/fides/api/service/connectors/http_connector.py @@ -6,10 +6,10 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR from fides.api.common_exceptions import ClientUnsuccessfulException -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.connectionconfig import ConnectionTestStatus from fides.api.models.policy import Policy -from fides.api.models.privacy_request import PrivacyRequest +from fides.api.models.privacy_request import PrivacyRequest, RequestTask from fides.api.schemas.connection_configuration import HttpsSchema from fides.api.service.connectors.base_connector import BaseConnector from fides.api.service.connectors.query_config import QueryConfig @@ -66,14 +66,15 @@ def test_connection(self) -> Optional[ConnectionTestStatus]: """ return ConnectionTestStatus.skipped - def query_config(self, node: TraversalNode) -> QueryConfig[Any]: + def query_config(self, node: ExecutionNode) -> QueryConfig[Any]: """Return the query config that corresponds to this connector type""" def retrieve_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, input_data: Dict[str, List[Any]], ) -> List[Row]: """Currently not supported as webhooks are not called at the collection level""" @@ -83,11 +84,11 @@ def retrieve_data( def mask_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, rows: List[Row], - input_data: Dict[str, List[Any]], ) -> int: """Currently not supported as webhooks are not called at the collection level""" raise NotImplementedError( diff --git a/src/fides/api/service/connectors/manual_connector.py b/src/fides/api/service/connectors/manual_connector.py deleted file mode 100644 
index d36b0a38ba..0000000000 --- a/src/fides/api/service/connectors/manual_connector.py +++ /dev/null @@ -1,106 +0,0 @@ -from typing import Any, Dict, List, Optional - -from fides.api.common_exceptions import PrivacyRequestPaused -from fides.api.graph.traversal import TraversalNode -from fides.api.models.policy import CurrentStep, Policy -from fides.api.models.privacy_request import ManualAction, PrivacyRequest -from fides.api.service.connectors.base_connector import BaseConnector -from fides.api.service.connectors.query_config import ManualQueryConfig -from fides.api.util.collection_util import Row - - -class ManualConnector(BaseConnector[None]): - def query_config(self, node: TraversalNode) -> ManualQueryConfig: - """ - The ManualQueryConfig generates instructions for the user to retrieve and mask - data manually. - """ - return ManualQueryConfig(node) - - def create_client(self) -> None: - """Not needed because this connector involves a human performing some lookup step""" - return None - - def close(self) -> None: - """No session to close for the Manual Connector""" - return None - - def test_connection(self) -> None: - """No automated test_connection available for the Manual Connector""" - return None - - def retrieve_data( # type: ignore - self, - node: TraversalNode, - policy: Policy, - privacy_request: PrivacyRequest, - input_data: Dict[str, List[Any]], - ) -> Optional[List[Row]]: - """ - Returns manually added data for the given collection if it exists, otherwise pauses the Privacy Request. - - On the event that we pause, caches the stopped step, stopped collection, and details needed to manually resume - the privacy request. - """ - cached_results: Optional[List[Row]] = privacy_request.get_manual_access_input( - node.address - ) - - if cached_results is not None: # None comparison intentional - privacy_request.cache_paused_collection_details() # Resets paused details to None - return cached_results - - query_config = self.query_config(node) - action_needed: Optional[ManualAction] = query_config.generate_query( - input_data, policy - ) - privacy_request.cache_paused_collection_details( - step=CurrentStep.access, - collection=node.address, - action_needed=[action_needed] if action_needed else None, - ) - - raise PrivacyRequestPaused( - f"Collection '{node.address.value}' waiting on manual data for privacy request '{privacy_request.id}'" - ) - - def mask_data( # type: ignore - self, - node: TraversalNode, - policy: Policy, - privacy_request: PrivacyRequest, - rows: List[Row], - input_data: Dict[str, List[Any]], - ) -> Optional[int]: - """If erasure confirmation has been added to the manual cache, continue, otherwise, - pause and wait for manual input. - - On the event that we pause, caches the stopped step, stopped collection, and details needed to manually resume - the privacy request. 
- """ - manual_cached_count: Optional[int] = privacy_request.get_manual_erasure_count( - node.address - ) - - if manual_cached_count is not None: # None comparison intentional - privacy_request.cache_paused_collection_details() # Resets paused details to None - return manual_cached_count - - query_config: ManualQueryConfig = self.query_config(node) - action_needed: List[ManualAction] = [] - for row in rows: - action: Optional[ManualAction] = query_config.generate_update_stmt( - row, policy, privacy_request - ) - if action: - action_needed.append(action) - - privacy_request.cache_paused_collection_details( - step=CurrentStep.erasure, - collection=node.address, - action_needed=action_needed if action_needed else None, - ) - - raise PrivacyRequestPaused( - f"Collection '{node.address.value}' waiting on manual erasure confirmation for privacy request '{privacy_request.id}'" - ) diff --git a/src/fides/api/service/connectors/manual_webhook_connector.py b/src/fides/api/service/connectors/manual_webhook_connector.py index 06efdda600..5deb6b2abf 100644 --- a/src/fides/api/service/connectors/manual_webhook_connector.py +++ b/src/fides/api/service/connectors/manual_webhook_connector.py @@ -1,15 +1,15 @@ from typing import Any, Dict, List -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.connectionconfig import ConnectionConfig, ConnectionTestStatus from fides.api.models.policy import Policy -from fides.api.models.privacy_request import PrivacyRequest +from fides.api.models.privacy_request import PrivacyRequest, RequestTask from fides.api.service.connectors.base_connector import BaseConnector from fides.api.util.collection_util import Row class ManualWebhookConnector(BaseConnector[None]): - def query_config(self, node: TraversalNode) -> None: # type: ignore + def query_config(self, node: ExecutionNode) -> None: # type: ignore """ Not applicable for this connector type. Manual Webhooks are not run as part of the traversal. There will not be a node associated with the ManualWebhook. @@ -35,9 +35,10 @@ def test_connection(self) -> ConnectionTestStatus: def retrieve_data( # type: ignore self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, input_data: Dict[str, List[Any]], ) -> None: """ @@ -47,11 +48,11 @@ def retrieve_data( # type: ignore def mask_data( # type: ignore self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, rows: List[Row], - input_data: Dict[str, List[Any]], ) -> None: """ Not applicable for a manual webhook. Manual webhooks are not called as part of the traversal. 
diff --git a/src/fides/api/service/connectors/mongodb_connector.py b/src/fides/api/service/connectors/mongodb_connector.py index 655a08c515..141536b952 100644 --- a/src/fides/api/service/connectors/mongodb_connector.py +++ b/src/fides/api/service/connectors/mongodb_connector.py @@ -5,10 +5,10 @@ from pymongo.errors import OperationFailure, ServerSelectionTimeoutError from fides.api.common_exceptions import ConnectionException -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.connectionconfig import ConnectionTestStatus from fides.api.models.policy import Policy -from fides.api.models.privacy_request import PrivacyRequest +from fides.api.models.privacy_request import PrivacyRequest, RequestTask from fides.api.schemas.connection_configuration.connection_secrets_mongodb import ( MongoDBSchema, ) @@ -48,7 +48,7 @@ def create_client(self) -> MongoClient: except ValueError: raise ConnectionException("Value Error connecting to MongoDB.") - def query_config(self, node: TraversalNode) -> QueryConfig[Any]: + def query_config(self, node: ExecutionNode) -> QueryConfig[Any]: """Query wrapper corresponding to the input traversal_node.""" return MongoQueryConfig(node) @@ -87,9 +87,10 @@ def test_connection(self) -> Optional[ConnectionTestStatus]: def retrieve_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, input_data: Dict[str, List[Any]], ) -> List[Row]: """Retrieve mongo data""" @@ -115,11 +116,11 @@ def retrieve_data( def mask_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, rows: List[Row], - input_data: Dict[str, List[Any]], ) -> int: """Execute a masking request""" query_config = self.query_config(node) diff --git a/src/fides/api/service/connectors/query_config.py b/src/fides/api/service/connectors/query_config.py index 0f68beca4d..cbd486d4a3 100644 --- a/src/fides/api/service/connectors/query_config.py +++ b/src/fides/api/service/connectors/query_config.py @@ -17,7 +17,7 @@ FieldPath, MaskingOverride, ) -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.policy import Policy, Rule from fides.api.models.privacy_request import ManualAction, PrivacyRequest from fides.api.schemas.policy import ActionType @@ -40,16 +40,16 @@ class QueryConfig(Generic[T], ABC): """A wrapper around a resource-type dependent query object that can generate runnable queries and string representations.""" - def __init__(self, node: TraversalNode): + def __init__(self, node: ExecutionNode): self.node = node def field_map(self) -> Dict[FieldPath, Field]: """Flattened FieldPaths of interest from this traversal_node.""" - return self.node.node.collection.field_dict + return self.node.collection.field_dict def top_level_field_map(self) -> Dict[FieldPath, Field]: """Top level FieldPaths on this traversal_node.""" - return self.node.node.collection.top_level_field_dict + return self.node.collection.top_level_field_dict def build_rule_target_field_paths( self, policy: Policy @@ -70,7 +70,7 @@ def build_rule_target_field_paths( targeted_field_paths = [] collection_categories: Dict[ str, List[FieldPath] - ] = self.node.node.collection.field_paths_by_category # type: ignore + ] = self.node.collection.field_paths_by_category # type: ignore for rule_cat in rule_categories: for collection_cat, field_paths in 
collection_categories.items(): if collection_cat.startswith(rule_cat): @@ -96,7 +96,7 @@ def query_sources(self) -> Dict[str, List[CollectionAddress]]: Translate keys from field paths to string values """ data: Dict[str, List[CollectionAddress]] = {} - for edge in self.node.incoming_edges(): + for edge in self.node.incoming_edges: append(data, edge.f2.field_path.string_path, edge.f1.collection_address()) return data @@ -282,7 +282,7 @@ def generate_query( locators: Dict[str, Any] = self.node.typed_filtered_values(input_data) get: List[str] = [ field_path.string_path - for field_path in self.node.node.collection.top_level_field_dict + for field_path in self.node.collection.top_level_field_dict ] if get and locators: @@ -364,7 +364,7 @@ def get_formatted_query_string( clauses: List[str], ) -> str: """Returns an SQL query string.""" - return f"SELECT {field_list} FROM {self.node.node.collection.name} WHERE {' OR '.join(clauses)}" + return f"SELECT {field_list} FROM {self.node.collection.name} WHERE {' OR '.join(clauses)}" def get_formatted_update_stmt( self, @@ -592,7 +592,7 @@ def get_formatted_query_string( clauses: List[str], ) -> str: """Returns a query string with double quotation mark formatting as required by Snowflake syntax.""" - return f'SELECT {field_list} FROM "{self.node.node.collection.name}" WHERE {" OR ".join(clauses)}' + return f'SELECT {field_list} FROM "{self.node.collection.name}" WHERE {" OR ".join(clauses)}' def format_key_map_for_update_stmt(self, fields: List[str]) -> List[str]: """Adds the appropriate formatting for update statements in this datastore.""" @@ -618,7 +618,7 @@ def get_formatted_query_string( ) -> str: """Returns a query string with double quotation mark formatting for tables that have the same names as Redshift reserved words.""" - return f'SELECT {field_list} FROM "{self.node.node.collection.name}" WHERE {" OR ".join(clauses)}' + return f'SELECT {field_list} FROM "{self.node.collection.name}" WHERE {" OR ".join(clauses)}' class BigQueryQueryConfig(QueryStringWithoutTuplesOverrideQueryConfig): @@ -633,7 +633,7 @@ def get_formatted_query_string( ) -> str: """Returns a query string with backtick formatting for tables that have the same names as BigQuery reserved words.""" - return f'SELECT {field_list} FROM `{self.node.node.collection.name}` WHERE {" OR ".join(clauses)}' + return f'SELECT {field_list} FROM `{self.node.collection.name}` WHERE {" OR ".join(clauses)}' def generate_update( self, row: Row, policy: Policy, request: PrivacyRequest, client: Engine @@ -782,7 +782,7 @@ def dry_run_query(self) -> Optional[str]: class DynamoDBQueryConfig(QueryConfig[DynamoDBStatement]): def __init__( - self, node: TraversalNode, attribute_definitions: List[Dict[str, Any]] + self, node: ExecutionNode, attribute_definitions: List[Dict[str, Any]] ): super().__init__(node) self.attribute_definitions = attribute_definitions diff --git a/src/fides/api/service/connectors/saas_connector.py b/src/fides/api/service/connectors/saas_connector.py index 195f97b73f..d5ad20a128 100644 --- a/src/fides/api/service/connectors/saas_connector.py +++ b/src/fides/api/service/connectors/saas_connector.py @@ -11,10 +11,10 @@ PostProcessingException, SkippingConsentPropagation, ) -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.connectionconfig import ConnectionConfig, ConnectionTestStatus from fides.api.models.policy import Policy -from fides.api.models.privacy_request import PrivacyRequest +from 
fides.api.models.privacy_request import PrivacyRequest, RequestTask from fides.api.schemas.limiter.rate_limit_config import RateLimitConfig from fides.api.schemas.policy import ActionType from fides.api.schemas.saas.saas_config import ( @@ -77,7 +77,7 @@ def __init__(self, configuration: ConnectionConfig): self.current_privacy_request: Optional[PrivacyRequest] = None self.current_saas_request: Optional[SaaSRequest] = None - def query_config(self, node: TraversalNode) -> SaaSQueryConfig: + def query_config(self, node: ExecutionNode) -> SaaSQueryConfig: """ Returns the query config for a given node which includes the endpoints and connector param values for the current collection. @@ -117,7 +117,7 @@ def get_rate_limit_config(self) -> Optional[RateLimitConfig]: ) def set_privacy_request_state( - self, privacy_request: PrivacyRequest, node: TraversalNode + self, privacy_request: PrivacyRequest, node: ExecutionNode ) -> None: """ Sets the class state for the current privacy request @@ -175,9 +175,10 @@ def create_client(self) -> AuthenticatedClient: @log_context(action_type=ActionType.access.value) def retrieve_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, input_data: Dict[str, List[Any]], ) -> List[Row]: """Retrieve data from SaaS APIs""" @@ -391,11 +392,11 @@ def process_response_data( @log_context(action_type=ActionType.erasure.value) def mask_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, rows: List[Row], - input_data: Dict[str, List[Any]], ) -> int: """Execute a masking request. Return the number of rows that have been updated.""" self.set_privacy_request_state(privacy_request, node) @@ -478,9 +479,10 @@ def relevant_consent_identities( @log_context(action_type=ActionType.consent.value) def run_consent_request( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, identity_data: Dict[str, Any], session: Session, ) -> bool: @@ -611,7 +613,7 @@ def _invoke_read_request_override( client: AuthenticatedClient, policy: Policy, privacy_request: PrivacyRequest, - node: TraversalNode, + node: ExecutionNode, input_data: Dict[str, List], secrets: Any, ) -> List[Row]: diff --git a/src/fides/api/service/connectors/saas_query_config.py b/src/fides/api/service/connectors/saas_query_config.py index 1eb557050a..550dc21ed4 100644 --- a/src/fides/api/service/connectors/saas_query_config.py +++ b/src/fides/api/service/connectors/saas_query_config.py @@ -11,7 +11,7 @@ from fides.api.common_exceptions import FidesopsException from fides.api.graph.config import ScalarField -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.policy import Policy from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.saas.saas_config import Endpoint, SaaSConfig, SaaSRequest @@ -40,7 +40,7 @@ class SaaSQueryConfig(QueryConfig[SaaSRequestParams]): def __init__( self, - node: TraversalNode, + node: ExecutionNode, endpoints: Dict[str, Endpoint], secrets: Dict[str, Any], data_protection_request: Optional[SaaSRequest] = None, @@ -283,7 +283,7 @@ def generate_query( if not self.current_request: raise FidesopsException( f"The 'read' action is not defined for the '{self.collection_name}' " - f"endpoint in {self.node.node.dataset.connection_key}" + f"endpoint in 
{self.node.connection_key}" ) # create the source of param values to populate the various placeholders diff --git a/src/fides/api/service/connectors/sql_connector.py b/src/fides/api/service/connectors/sql_connector.py index b8c2c15227..0977c40a1a 100644 --- a/src/fides/api/service/connectors/sql_connector.py +++ b/src/fides/api/service/connectors/sql_connector.py @@ -24,10 +24,10 @@ ConnectionException, SSHTunnelConfigNotFoundException, ) -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.connectionconfig import ConnectionConfig, ConnectionTestStatus from fides.api.models.policy import Policy -from fides.api.models.privacy_request import PrivacyRequest +from fides.api.models.privacy_request import PrivacyRequest, RequestTask from fides.api.schemas.connection_configuration import ( ConnectionConfigSecretsSchema, MicrosoftSQLServerSchema, @@ -103,8 +103,8 @@ def default_cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: def build_uri(self) -> str: """Build a database specific uri connection string""" - def query_config(self, node: TraversalNode) -> SQLQueryConfig: - """Query wrapper corresponding to the input traversal_node.""" + def query_config(self, node: ExecutionNode) -> SQLQueryConfig: + """Query wrapper corresponding to the input execution_node.""" return SQLQueryConfig(node) def test_connection(self) -> Optional[ConnectionTestStatus]: @@ -130,9 +130,10 @@ def test_connection(self) -> Optional[ConnectionTestStatus]: def retrieve_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, input_data: Dict[str, List[Any]], ) -> List[Row]: """Retrieve sql data""" @@ -149,11 +150,11 @@ def retrieve_data( def mask_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, rows: List[Row], - input_data: Dict[str, List[Any]], ) -> int: """Execute a masking request. Returns the number of records masked""" query_config = self.query_config(node) @@ -438,8 +439,8 @@ def set_schema(self, connection: Connection) -> None: connection.execute(stmt) # Overrides SQLConnector.query_config - def query_config(self, node: TraversalNode) -> RedshiftQueryConfig: - """Query wrapper corresponding to the input traversal_node.""" + def query_config(self, node: ExecutionNode) -> RedshiftQueryConfig: + """Query wrapper corresponding to the input execution node.""" return RedshiftQueryConfig(node) @@ -476,17 +477,17 @@ def create_client(self) -> Engine: ) # Overrides SQLConnector.query_config - def query_config(self, node: TraversalNode) -> BigQueryQueryConfig: - """Query wrapper corresponding to the input traversal_node.""" + def query_config(self, node: ExecutionNode) -> BigQueryQueryConfig: + """Query wrapper corresponding to the input execution_node.""" return BigQueryQueryConfig(node) def mask_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, rows: List[Row], - input_data: Dict[str, List[Any]], ) -> int: """Execute a masking request. 
Returns the number of records masked""" query_config = self.query_config(node) @@ -534,8 +535,8 @@ def build_uri(self) -> str: url: str = Snowflake_URL(**kwargs) return url - def query_config(self, node: TraversalNode) -> SQLQueryConfig: - """Query wrapper corresponding to the input traversal_node.""" + def query_config(self, node: ExecutionNode) -> SQLQueryConfig: + """Query wrapper corresponding to the input execution_node.""" return SnowflakeQueryConfig(node) @@ -566,8 +567,8 @@ def build_uri(self) -> URL: return url - def query_config(self, node: TraversalNode) -> SQLQueryConfig: - """Query wrapper corresponding to the input traversal_node.""" + def query_config(self, node: ExecutionNode) -> SQLQueryConfig: + """Query wrapper corresponding to the input execution_node.""" return MicrosoftSQLServerQueryConfig(node) @staticmethod diff --git a/src/fides/api/service/privacy_request/request_runner_service.py b/src/fides/api/service/privacy_request/request_runner_service.py index 0f8121de40..732132186e 100644 --- a/src/fides/api/service/privacy_request/request_runner_service.py +++ b/src/fides/api/service/privacy_request/request_runner_service.py @@ -5,7 +5,6 @@ import requests from loguru import logger from pydantic import ValidationError -from redis.exceptions import DataError from sqlalchemy.orm import Query, Session from fides.api import common_exceptions @@ -15,21 +14,14 @@ ManualWebhookFieldsUnset, MessageDispatchException, NoCachedManualWebhookEntry, + PrivacyRequestExit, PrivacyRequestPaused, ) from fides.api.db.session import get_db_session -from fides.api.graph.analytics_events import ( - failed_graph_analytics_event, - fideslog_graph_failure, -) -from fides.api.graph.config import CollectionAddress, GraphDataset +from fides.api.graph.config import CollectionAddress from fides.api.graph.graph import DatasetGraph from fides.api.models.audit_log import AuditLog, AuditLogAction -from fides.api.models.connectionconfig import ( - AccessLevel, - ConnectionConfig, - ConnectionType, -) +from fides.api.models.connectionconfig import AccessLevel, ConnectionConfig from fides.api.models.datasetconfig import DatasetConfig from fides.api.models.manual_webhook import AccessManualWebhook from fides.api.models.policy import ( @@ -63,22 +55,17 @@ from fides.api.service.messaging.message_dispatch_service import dispatch_message from fides.api.service.storage.storage_uploader_service import upload from fides.api.task.filter_results import filter_data_categories +from fides.api.task.graph_runners import access_runner, consent_runner, erasure_runner from fides.api.task.graph_task import ( + build_consent_dataset_graph, + filter_by_enabled_actions, get_cached_data_for_erasures, - run_access_request, - run_consent_request, - run_erasure, ) from fides.api.tasks import DatabaseTask, celery_app from fides.api.tasks.scheduled.scheduler import scheduler -from fides.api.util.cache import ( - FidesopsRedis, - get_async_task_tracking_cache_key, - get_cache, -) +from fides.api.util.cache import cache_task_tracking_key from fides.api.util.collection_util import Row from fides.api.util.logger import Pii, _log_exception, _log_warning -from fides.api.util.wrappers import sync from fides.common.api.v1.urn_registry import ( PRIVACY_REQUEST_TRANSFER_TO_PARENT, V1_URL_PREFIX, @@ -226,6 +213,7 @@ def upload_access_results( # pylint: disable=R0912 if not access_result: logger.info("No results returned for access request {}", privacy_request.id) + rule_filtered_results: Dict[str, Dict[str, List[Row]]] = {} for rule in 
policy.get_rules_for_action( # pylint: disable=R1702 action_type=ActionType.access ): @@ -245,6 +233,7 @@ filtered_results.update( manual_data ) # Add manual data directly to each upload packet + rule_filtered_results[rule.key] = filtered_results logger.info( "Starting access request upload for rule {} for privacy request {}", @@ -271,7 +260,12 @@ Pii(str(exc)), ) privacy_request.status = PrivacyRequestStatus.error - + # Save the results we uploaded to the user for later retrieval + privacy_request.save_filtered_access_results(session, rule_filtered_results) + # Saving access request URLs on the privacy request in case DSR 3.0 + # exits processing before the email is sent + privacy_request.access_result_urls = {"access_result_urls": download_urls} + privacy_request.save(session) return download_urls @@ -280,29 +274,21 @@ def queue_privacy_request( from_webhook_id: Optional[str] = None, from_step: Optional[str] = None, ) -> str: - cache: FidesopsRedis = get_cache() - logger.info("queueing privacy request") + logger.info( + "Queueing privacy request {} from step {}", privacy_request_id, from_step + ) task = run_privacy_request.delay( privacy_request_id=privacy_request_id, from_webhook_id=from_webhook_id, from_step=from_step, ) - try: - cache.set( - get_async_task_tracking_cache_key(privacy_request_id), - task.task_id, - ) - except DataError: - logger.debug( - "Error tracking task_id for request with id {}", privacy_request_id - ) + cache_task_tracking_key(privacy_request_id, task.task_id) return task.task_id @celery_app.task(base=DatabaseTask, bind=True) -@sync -async def run_privacy_request( +def run_privacy_request( self: DatabaseTask, privacy_request_id: str, from_webhook_id: Optional[str] = None, @@ -330,8 +316,6 @@ async def run_privacy_request( f"Privacy request with id {privacy_request_id} not found" ) - privacy_request.cache_failed_checkpoint_details() # Reset failed step and collection to None - if privacy_request.status == PrivacyRequestStatus.canceled: logger.info( "Terminating privacy request {}: request canceled.", privacy_request.id ) @@ -356,9 +340,11 @@ if not manual_webhook_erasure_results.proceed: return + # Pre-Webhooks CHECKPOINT if can_run_checkpoint( request_checkpoint=CurrentStep.pre_webhooks, from_checkpoint=resume_step ): + privacy_request.cache_failed_checkpoint_details(CurrentStep.pre_webhooks) # Run pre-execution webhooks proceed = run_webhooks_and_report_status( session, @@ -387,39 +373,62 @@ fides_connector_datasets: Set[str] = filter_fides_connector_datasets( connection_configs ) - access_result_urls: List[str] = [] + # Access CHECKPOINT if ( policy.get_rules_for_action(action_type=ActionType.access) or policy.get_rules_for_action(action_type=ActionType.erasure) ) and can_run_checkpoint( request_checkpoint=CurrentStep.access, from_checkpoint=resume_step ): - access_result: Dict[str, List[Row]] = await run_access_request( + privacy_request.cache_failed_checkpoint_details(CurrentStep.access) + access_runner( privacy_request=privacy_request, policy=policy, graph=dataset_graph, connection_configs=connection_configs, identity=identity_data, session=session, + privacy_request_proceed=True, # Should always be True unless we're testing + ) + + # Upload Access Results CHECKPOINT + access_result_urls: List[str] = [] + raw_access_results: Dict = privacy_request.get_raw_access_results() + if (
policy.get_rules_for_action(action_type=ActionType.access) + or policy.get_rules_for_action( + action_type=ActionType.erasure + ) # Intentional to support requeuing the Privacy Request after the Access step for DSR 3.0 for both access/erasure requests + ) and can_run_checkpoint( + request_checkpoint=CurrentStep.upload_access, + from_checkpoint=resume_step, + ): + privacy_request.cache_failed_checkpoint_details( + CurrentStep.upload_access + ) + filtered_access_results = filter_by_enabled_actions( + raw_access_results, connection_configs ) access_result_urls = upload_access_results( session, policy, - access_result, + filtered_access_results, dataset_graph, privacy_request, manual_webhook_access_results.manual_data, fides_connector_datasets, ) + # Erasure CHECKPOINT if policy.get_rules_for_action( action_type=ActionType.erasure ) and can_run_checkpoint( request_checkpoint=CurrentStep.erasure, from_checkpoint=resume_step ): + privacy_request.cache_failed_checkpoint_details(CurrentStep.erasure) # We only need to run the erasure once until masking strategies are handled - await run_erasure( + erasure_runner( privacy_request=privacy_request, policy=policy, graph=dataset_graph, @@ -429,6 +438,18 @@ async def run_privacy_request( privacy_request.id ), session=session, + privacy_request_proceed=True, # Should always be True unless we're testing + ) + + # Finalize Erasure CHECKPOINT + if can_run_checkpoint( + request_checkpoint=CurrentStep.finalize_erasure, + from_checkpoint=resume_step, + ): + # This checkpoint allows a Privacy Request to be re-queued + # after the Erasure Step is complete for DSR 3.0 + privacy_request.cache_failed_checkpoint_details( + CurrentStep.finalize_erasure ) if policy.get_rules_for_action( @@ -437,13 +458,26 @@ async def run_privacy_request( request_checkpoint=CurrentStep.consent, from_checkpoint=resume_step, ): - await run_consent_request( + privacy_request.cache_failed_checkpoint_details(CurrentStep.consent) + consent_runner( privacy_request=privacy_request, policy=policy, graph=build_consent_dataset_graph(datasets), connection_configs=connection_configs, identity=identity_data, session=session, + privacy_request_proceed=True, # Should always be True unless we're testing + ) + + # Finalize Consent CHECKPOINT + if can_run_checkpoint( + request_checkpoint=CurrentStep.finalize_consent, + from_checkpoint=resume_step, + ): + # This checkpoint allows a Privacy Request to be re-queued + # after the Consent Step is complete for DSR 3.0 + privacy_request.cache_failed_checkpoint_details( + CurrentStep.finalize_consent ) except PrivacyRequestPaused as exc: @@ -451,17 +485,21 @@ async def run_privacy_request( _log_warning(exc, CONFIG.dev_mode) return + except PrivacyRequestExit: + # Privacy Request is exiting while awaiting sub-task processing (Request Tasks) + # The access, consent, and erasure runners for DSR 3.0 throw this exception after the + # Request Tasks have been built. The Privacy Request will be requeued from + # the appropriate checkpoint when all the Request Tasks have run.
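To make the PrivacyRequestExit flow above concrete, here is a toy, self-contained sketch of the exit-and-requeue cycle. The names mirror this diff, but the class and function bodies below are illustrative stand-ins, not Fides code:

    from enum import Enum


    class CurrentStep(Enum):  # stand-in for fides.api.models.policy.CurrentStep
        access = "access"
        upload_access = "upload_access"


    class PrivacyRequestExit(Exception):
        """Raised once the step's Request Tasks have been queued."""


    def access_runner(checkpoint_log: list) -> None:
        # Analog of privacy_request.cache_failed_checkpoint_details(CurrentStep.access):
        # record where to resume from, then hand off to the Request Tasks.
        checkpoint_log.append(CurrentStep.access)
        raise PrivacyRequestExit  # Request Tasks now run out-of-band


    def run_privacy_request(checkpoint_log: list) -> None:
        try:
            access_runner(checkpoint_log)
        except PrivacyRequestExit:
            # Exit cleanly; the request is requeued later from the saved
            # checkpoint once the terminator task completes.
            return


    log: list = []
    run_privacy_request(log)
    assert log == [CurrentStep.access]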
+ return + except BaseException as exc: # pylint: disable=broad-except privacy_request.error_processing(db=session) - # Send analytics to Fideslog - await fideslog_graph_failure( - failed_graph_analytics_event(privacy_request, exc) - ) # If dev mode, log traceback _log_exception(exc, CONFIG.dev_mode) return # Check if privacy request needs erasure or consent emails sent + # Email post-send CHECKPOINT if ( ( policy.get_rules_for_action(action_type=ActionType.erasure) @@ -473,6 +511,7 @@ async def run_privacy_request( ) and needs_batch_email_send(session, identity_data, privacy_request) ): + privacy_request.cache_failed_checkpoint_details(CurrentStep.email_post_send) privacy_request.pause_processing_for_email_send(session) logger.info( "Privacy request '{}' exiting: awaiting email send.", @@ -480,11 +519,12 @@ async def run_privacy_request( ) return - # Run post-execution webhooks + # Post Webhooks CHECKPOINT if can_run_checkpoint( request_checkpoint=CurrentStep.post_webhooks, from_checkpoint=resume_step, ): + privacy_request.cache_failed_checkpoint_details(CurrentStep.post_webhooks) proceed = run_webhooks_and_report_status( db=session, privacy_request=privacy_request, @@ -499,15 +539,19 @@ async def run_privacy_request( action_type=ActionType.consent ): try: + if not access_result_urls: + # For DSR 3.0, if the request had both access and erasure rules, this needs to be fetched + # from the database because the Privacy Request would have exited + # processing and lost access to the access_result_urls in memory + access_result_urls = (privacy_request.access_result_urls or {}).get( + "access_result_urls", [] + ) initiate_privacy_request_completion_email( session, policy, access_result_urls, identity_data ) except (IdentityNotFoundException, MessageDispatchException) as e: privacy_request.error_processing(db=session) # If dev mode, log traceback - await fideslog_graph_failure( - failed_graph_analytics_event(privacy_request, e) - ) _log_exception(e, CONFIG.dev_mode) return privacy_request.finished_processing_at = datetime.utcnow() @@ -525,31 +569,6 @@ async def run_privacy_request( privacy_request.save(db=session) -def build_consent_dataset_graph(datasets: List[DatasetConfig]) -> DatasetGraph: - """ - Build the starting DatasetGraph for consent requests. - - Consent Graph has one node per dataset. Nodes must be of saas type and have consent requests defined. 
- """ - consent_datasets: List[GraphDataset] = [] - - for dataset_config in datasets: - connection_type: ConnectionType = ( - dataset_config.connection_config.connection_type # type: ignore - ) - saas_config: Optional[Dict] = dataset_config.connection_config.saas_config - if ( - connection_type == ConnectionType.saas - and saas_config - and saas_config.get("consent_requests") - ): - consent_datasets.append( - dataset_config.get_dataset_with_stubbed_collection() # type: ignore[arg-type, assignment] - ) - - return DatasetGraph(*consent_datasets) - - def initiate_privacy_request_completion_email( session: Session, policy: Policy, diff --git a/src/fides/api/service/privacy_request/request_service.py b/src/fides/api/service/privacy_request/request_service.py index bd725d6a2f..e00a7d50fa 100644 --- a/src/fides/api/service/privacy_request/request_service.py +++ b/src/fides/api/service/privacy_request/request_service.py @@ -1,22 +1,36 @@ from __future__ import annotations from asyncio import sleep -from datetime import datetime +from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Set from httpx import AsyncClient from loguru import logger +from sqlalchemy import text +from sqlalchemy.orm import Query +from sqlalchemy.sql.elements import TextClause from fides.api.common_exceptions import PrivacyRequestNotFound from fides.api.models.policy import Policy -from fides.api.models.privacy_request import PrivacyRequest, PrivacyRequestStatus +from fides.api.models.privacy_request import ( + EXITED_EXECUTION_LOG_STATUSES, + ExecutionLogStatus, + PrivacyRequest, + PrivacyRequestStatus, +) from fides.api.schemas.drp_privacy_request import DrpPrivacyRequestCreate from fides.api.schemas.masking.masking_secrets import MaskingSecretCache from fides.api.schemas.policy import ActionType from fides.api.schemas.privacy_request import PrivacyRequestResponse from fides.api.schemas.redis_cache import Identity from fides.api.service.masking.strategy.masking_strategy import MaskingStrategy +from fides.api.tasks import DatabaseTask, celery_app +from fides.api.tasks.scheduled.scheduler import scheduler from fides.common.api.v1.urn_registry import PRIVACY_REQUESTS, V1_URL_PREFIX +from fides.config import CONFIG + +PRIVACY_REQUEST_STATUS_CHANGE_POLL = "privacy_request_status_change_poll" +DSR_DATA_REMOVAL = "dsr_data_removal" def build_required_privacy_request_kwargs( @@ -137,3 +151,159 @@ async def poll_server_for_completion( raise TimeoutError( f"Timeout of {timeout_seconds} seconds has been exceeded while waiting for privacy request {privacy_request_id}" ) + + +def initiate_poll_for_exited_privacy_request_tasks() -> None: + """Initiates scheduler to check if a Privacy Request's status needs to be flipped when all + Request Tasks have had a chance to run""" + + if CONFIG.test_mode: + return + + assert ( + scheduler.running + ), "Scheduler is not running! Cannot add Privacy Request Status Change job." + + logger.info("Initiating scheduler for Privacy Request Status Change") + scheduler.add_job( + func=poll_for_exited_privacy_request_tasks, + trigger="interval", + kwargs={}, + id=PRIVACY_REQUEST_STATUS_CHANGE_POLL, + coalesce=True, + replace_existing=True, + seconds=CONFIG.execution.state_polling_interval, + ) + + +@celery_app.task(base=DatabaseTask, bind=True) +def poll_for_exited_privacy_request_tasks(self: DatabaseTask) -> Set[str]: + """ + Mark a privacy request as errored if all of its Request Tasks have run but some have errored. 
+ + When a Request Task fails, it marks itself and *every Request Task that can be reached by the current + Request Task* as failed. However, other Request Tasks independent of this path should still have an + opportunity to run. We wait until everything has run before marking the Privacy Request as errored so it + can be reprocessed. + """ + with self.get_new_session() as db: + logger.info("Polling for privacy requests awaiting status change") + in_progress_privacy_requests = ( + db.query(PrivacyRequest) + .filter(PrivacyRequest.status == PrivacyRequestStatus.in_processing) + .order_by(PrivacyRequest.created_at) + ) + + def some_errored(tasks: Query) -> bool: + """All statuses have exited and at least one is errored""" + statuses: List[ExecutionLogStatus] = [tsk.status for tsk in tasks] + all_exited = all( + status in EXITED_EXECUTION_LOG_STATUSES for status in statuses + ) + return all_exited and ExecutionLogStatus.error in statuses + + marked_as_errored: Set[str] = set() + for pr in in_progress_privacy_requests.all(): + if pr.consent_tasks.count(): + # Consent propagation tasks - these are not created until access and erasure steps are complete. + if some_errored(pr.consent_tasks): + logger.info(f"Marking consent step of {pr.id} as error") + pr.error_processing(db) + marked_as_errored.add(pr.id) + + if pr.erasure_tasks.count(): + # Erasure tasks are created at the same time as access tasks but if any are errored, this means + # we made it to the erasure section + if some_errored(pr.erasure_tasks): + logger.info(f"Marking erasure step of {pr.id} as error") + pr.error_processing(db) + marked_as_errored.add(pr.id) + + if pr.access_tasks.count(): + if some_errored(pr.access_tasks): + logger.info(f"Marking access step of {pr.id} as error") + pr.error_processing(db) + marked_as_errored.add(pr.id) + + return marked_as_errored + + +def initiate_scheduled_dsr_data_removal() -> None: + """Initiates scheduler to clean up obsolete access and erasure data""" + + if CONFIG.test_mode: + return + + assert ( + scheduler.running + ), "Scheduler is not running! Cannot add DSR data removal job." + + logger.info("Initiating scheduler for DSR Data Removal") + scheduler.add_job( + func=remove_saved_dsr_data, + kwargs={}, + id=DSR_DATA_REMOVAL, + coalesce=False, + replace_existing=True, + trigger="cron", + minute="0", + hour="2", + day="*", + timezone="US/Eastern", + ) + + +@celery_app.task(base=DatabaseTask, bind=True) +def remove_saved_dsr_data(self: DatabaseTask) -> None: + """ + Remove saved customer data that is no longer needed to facilitate running the access or erasure request. + """ + with self.get_new_session() as db: + logger.info("Running DSR Data Removal Task to clean up obsolete user data") + + # Remove old request tasks which potentially contain encrypted PII + remove_dsr_data: TextClause = text( + """ + DELETE FROM requesttask + USING privacyrequest + WHERE requesttask.privacy_request_id = privacyrequest.id + AND requesttask.created_at < :ttl + AND privacyrequest.status = 'complete'; + """ + ) + + result = db.execute( + remove_dsr_data, + { + "ttl": ( + datetime.now() + - timedelta(seconds=CONFIG.execution.request_task_ttl) + ), + }, + ) + affected_rows = result.rowcount + logger.info( + f"Deleted {affected_rows} expired request tasks via DSR Data Removal Task." + ) + + # Remove columns from old privacyrequests that potentially contain encrypted PII + # or URLs that contain encrypted PII.
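As a quick illustration of the TTL arithmetic these cleanup statements rely on (the constants below are assumed example values; the real ones come from CONFIG.execution.request_task_ttl and CONFIG.redis.default_ttl_seconds):

    from datetime import datetime, timedelta

    # Assumed example values, not the shipped defaults.
    REQUEST_TASK_TTL_SECONDS = 7 * 24 * 60 * 60
    REDIS_DEFAULT_TTL_SECONDS = 604_800

    # Each cutoff is bound as the :ttl parameter of its statement, so rows on
    # completed requests that are older than the cutoff are deleted or nulled out.
    request_task_cutoff = datetime.now() - timedelta(seconds=REQUEST_TASK_TTL_SECONDS)
    privacy_request_cutoff = datetime.now() - timedelta(seconds=REDIS_DEFAULT_TTL_SECONDS)
    print(request_task_cutoff, privacy_request_cutoff)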
+ remove_data_from_privacy_request: TextClause = text( + """ + UPDATE privacyrequest + SET filtered_final_upload = null, access_result_urls = null + WHERE privacyrequest.updated_at < :ttl + AND privacyrequest.status = 'complete'; + """ + ) + + db.execute( + remove_data_from_privacy_request, + { + "ttl": ( # Using Redis Default TTL Seconds by default + datetime.now() - timedelta(seconds=CONFIG.redis.default_ttl_seconds) + ), + }, + ) + + db.commit() diff --git a/src/fides/api/service/saas_request/override_implementations/adyen_request_overrides.py b/src/fides/api/service/saas_request/override_implementations/adyen_request_overrides.py index 78d9f8642c..4f86bf64a9 100644 --- a/src/fides/api/service/saas_request/override_implementations/adyen_request_overrides.py +++ b/src/fides/api/service/saas_request/override_implementations/adyen_request_overrides.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.policy import Policy from fides.api.models.privacy_request import PrivacyRequest from fides.api.service.connectors.saas.authenticated_client import AuthenticatedClient @@ -19,7 +19,7 @@ @register("adyen_user_read", [SaaSRequestType.READ]) def adyen_user_read( client: AuthenticatedClient, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, input_data: Dict[str, List[Any]], diff --git a/src/fides/api/service/saas_request/override_implementations/appsflyer_request_overrides.py b/src/fides/api/service/saas_request/override_implementations/appsflyer_request_overrides.py index 88fa2d7b1b..90f6cf2c9b 100644 --- a/src/fides/api/service/saas_request/override_implementations/appsflyer_request_overrides.py +++ b/src/fides/api/service/saas_request/override_implementations/appsflyer_request_overrides.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.policy import Policy from fides.api.models.privacy_request import PrivacyRequest from fides.api.service.connectors.saas.authenticated_client import AuthenticatedClient @@ -14,7 +14,7 @@ @register("appsflyer_user_read", [SaaSRequestType.READ]) def appsflyer_user_read( client: AuthenticatedClient, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, input_data: Dict[str, List[Any]], diff --git a/src/fides/api/service/saas_request/override_implementations/firebase_auth_request_overrides.py b/src/fides/api/service/saas_request/override_implementations/firebase_auth_request_overrides.py index 80801fb39f..b204270da7 100644 --- a/src/fides/api/service/saas_request/override_implementations/firebase_auth_request_overrides.py +++ b/src/fides/api/service/saas_request/override_implementations/firebase_auth_request_overrides.py @@ -6,7 +6,7 @@ from loguru import logger from fides.api.common_exceptions import FidesopsException -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.policy import Policy from fides.api.models.privacy_request import PrivacyRequest from fides.api.service.connectors.saas.authenticated_client import AuthenticatedClient @@ -22,7 +22,7 @@ @register("firebase_auth_user_access", [SaaSRequestType.READ]) def firebase_auth_user_access( # pylint: disable=R0914 client: AuthenticatedClient, - node: TraversalNode, + node: ExecutionNode, policy: 
Policy, privacy_request: PrivacyRequest, input_data: Dict[str, List[Any]], diff --git a/src/fides/api/service/saas_request/override_implementations/iterate_request_overrides.py b/src/fides/api/service/saas_request/override_implementations/iterate_request_overrides.py index 9174d2eb95..7a862d8064 100644 --- a/src/fides/api/service/saas_request/override_implementations/iterate_request_overrides.py +++ b/src/fides/api/service/saas_request/override_implementations/iterate_request_overrides.py @@ -2,7 +2,7 @@ import pydash -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.policy import Policy from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.saas.shared_schemas import HTTPMethod, SaaSRequestParams @@ -17,7 +17,7 @@ @register("iterate_company_read", [SaaSRequestType.READ]) def iterate_company_read( client: AuthenticatedClient, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, input_data: Dict[str, List[Any]], diff --git a/src/fides/api/service/saas_request/override_implementations/mailchimp_request_overrides.py b/src/fides/api/service/saas_request/override_implementations/mailchimp_request_overrides.py index 9149096e5f..e279a3adbc 100644 --- a/src/fides/api/service/saas_request/override_implementations/mailchimp_request_overrides.py +++ b/src/fides/api/service/saas_request/override_implementations/mailchimp_request_overrides.py @@ -3,7 +3,7 @@ import pydash -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.policy import Policy from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.saas.shared_schemas import HTTPMethod, SaaSRequestParams @@ -18,7 +18,7 @@ @register("mailchimp_messages_access", [SaaSRequestType.READ]) def mailchimp_messages_access( client: AuthenticatedClient, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, input_data: Dict[str, List[Any]], diff --git a/src/fides/api/service/saas_request/override_implementations/oracle_responsys_request_overrides.py b/src/fides/api/service/saas_request/override_implementations/oracle_responsys_request_overrides.py index bc64cafc3d..4bab4bbd89 100644 --- a/src/fides/api/service/saas_request/override_implementations/oracle_responsys_request_overrides.py +++ b/src/fides/api/service/saas_request/override_implementations/oracle_responsys_request_overrides.py @@ -4,7 +4,7 @@ import pydash from fides.api.common_exceptions import FidesopsException -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.policy import Policy from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.saas.shared_schemas import HTTPMethod, SaaSRequestParams @@ -20,7 +20,7 @@ @register("oracle_responsys_profile_list_recipients_read", [SaaSRequestType.READ]) def oracle_responsys_profile_list_recipients_read( client: AuthenticatedClient, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, input_data: Dict[str, List[Any]], diff --git a/src/fides/api/service/saas_request/override_implementations/statsig_enterprise_request_overrides.py b/src/fides/api/service/saas_request/override_implementations/statsig_enterprise_request_overrides.py index 22d8943333..f0fe6a23ff 100644 --- 
a/src/fides/api/service/saas_request/override_implementations/statsig_enterprise_request_overrides.py +++ b/src/fides/api/service/saas_request/override_implementations/statsig_enterprise_request_overrides.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.policy import Policy from fides.api.models.privacy_request import PrivacyRequest from fides.api.service.connectors.saas.authenticated_client import AuthenticatedClient @@ -14,7 +14,7 @@ @register("statsig_enterprise_user_read", [SaaSRequestType.READ]) def statsig_enterprise_user_read( client: AuthenticatedClient, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, input_data: Dict[str, List[Any]], diff --git a/src/fides/api/service/storage/storage_uploader_service.py b/src/fides/api/service/storage/storage_uploader_service.py index a94143add1..320a267e04 100644 --- a/src/fides/api/service/storage/storage_uploader_service.py +++ b/src/fides/api/service/storage/storage_uploader_service.py @@ -98,7 +98,15 @@ def _s3_uploader( auth_method = config.details[StorageDetails.AUTH_METHOD.value] return upload_to_s3( - config.secrets, data, bucket_name, file_key, config.format.value, privacy_request, auth_method, data_category_field_mapping, data_use_map # type: ignore + config.secrets, # type: ignore + data, + bucket_name, + file_key, + config.format.value, # type: ignore + privacy_request, + auth_method, + data_category_field_mapping, + data_use_map, ) diff --git a/src/fides/api/task/create_request_tasks.py b/src/fides/api/task/create_request_tasks.py new file mode 100644 index 0000000000..094398c74c --- /dev/null +++ b/src/fides/api/task/create_request_tasks.py @@ -0,0 +1,572 @@ +# pylint: disable=too-many-lines +import json +from typing import Any, Dict, List, Optional, Set + +import networkx +from loguru import logger +from networkx import NetworkXNoCycle +from sqlalchemy.orm import Query, Session + +from fides.api.common_exceptions import TraversalError +from fides.api.graph.config import ( + ROOT_COLLECTION_ADDRESS, + TERMINATOR_ADDRESS, + CollectionAddress, + FieldAddress, +) +from fides.api.graph.graph import DatasetGraph +from fides.api.graph.traversal import ARTIFICIAL_NODES, Traversal, TraversalNode +from fides.api.models.connectionconfig import ConnectionConfig +from fides.api.models.policy import Policy +from fides.api.models.privacy_request import ( + COMPLETED_EXECUTION_LOG_STATUSES, + ExecutionLogStatus, + PrivacyRequest, + RequestTask, +) +from fides.api.schemas.policy import ActionType +from fides.api.task.deprecated_graph_task import format_data_use_map_for_caching +from fides.api.task.execute_request_tasks import log_task_queued, queue_request_task +from fides.api.util.cache import CustomJSONEncoder + + +def _add_edge_if_no_nodes( + traversal_nodes: Dict[CollectionAddress, TraversalNode], + networkx_graph: networkx.DiGraph, +) -> None: + """ + Adds an edge from the root node to the terminator node, altering the networkx_graph in-place. 
+ + Handles edge case if there are no traversal nodes in the graph at all + """ + if not traversal_nodes.items(): + networkx_graph.add_edge(ROOT_COLLECTION_ADDRESS, TERMINATOR_ADDRESS) + + +def build_access_networkx_digraph( + traversal_nodes: Dict[CollectionAddress, TraversalNode], + end_nodes: List[CollectionAddress], + traversal: Traversal, +) -> networkx.DiGraph: + """ + DSR 3.0: Builds an access networkx graph to get consistent formatting of nodes to build the Request Tasks, + regardless of whether node is real or artificial. + + Primarily though, this lets us use networkx.descendants to calculate every node that can be reached from the current + node to more easily mark downstream nodes as failed if the current node fails. + """ + networkx_graph = networkx.DiGraph() + networkx_graph.add_nodes_from(traversal_nodes.keys()) + networkx_graph.add_nodes_from(ARTIFICIAL_NODES) + + # The first nodes visited are the nodes that only need identity data. + # Therefore, they are all immediately downstream of the root. + first_nodes: Dict[FieldAddress, str] = traversal.extract_seed_field_addresses() + + for node in [ + CollectionAddress(initial_node.dataset, initial_node.collection) + for initial_node in first_nodes + ]: + networkx_graph.add_edge(ROOT_COLLECTION_ADDRESS, node) + + for collection_address, traversal_node in traversal_nodes.items(): + for child in traversal_node.children: + # For every node, add a downstream edge to its children + # that were calculated in traversal.traverse + networkx_graph.add_edge(collection_address, child) + + for node in end_nodes: + # Connect the end nodes, those that have no downstream dependencies, to the terminator node + networkx_graph.add_edge(node, TERMINATOR_ADDRESS) + + _add_edge_if_no_nodes(traversal_nodes, networkx_graph) + return networkx_graph + + +def _evaluate_erasure_dependencies( + traversal_node: TraversalNode, end_nodes: List[CollectionAddress] +) -> Set[CollectionAddress]: + """ + Return a set of collection addresses corresponding to collections that need + to be erased before the given task. + + Remove the dependent collection addresses + from `end_nodes` so they can be executed in the correct order. If a task does + not have any dependencies it is linked directly to the root node + """ + erase_after = traversal_node.node.collection.erase_after + for collection in erase_after: + if collection in end_nodes: + # end_node list is modified in place + end_nodes.remove(collection) + # this task will execute after the collections in `erase_after` or + # execute at the beginning by linking it to the root node + if len(erase_after): + erase_after.add(ROOT_COLLECTION_ADDRESS) + return erase_after if len(erase_after) else {ROOT_COLLECTION_ADDRESS} + + +def build_erasure_networkx_digraph( + traversal_nodes: Dict[CollectionAddress, TraversalNode], + end_nodes: List[CollectionAddress], +) -> networkx.DiGraph: + """ + DSR 3.0: Builds a networkx graph of erasure nodes to get consistent formatting of nodes to build the Request Tasks, + regardless of whether node is real or artificial. + + Erasure graphs are different from access graphs, in that we've queried all the data we need upfront in the access + graphs, so that all nodes can in theory run entirely in parallel, except for the "erase_after" dependencies. + + We tack on the "erase_after" dependencies here that aren't captured in traversal.traverse. 
+ + """ + networkx_graph = networkx.DiGraph() + networkx_graph.add_nodes_from(traversal_nodes.keys()) + networkx_graph.add_nodes_from(ARTIFICIAL_NODES) + + for node_name, traversal_node in traversal_nodes.items(): + # Add an edge from the root node to the current node, unless explicit erasure + # dependencies are defined. Modifies end_nodes in place + erasure_dependencies: Set[CollectionAddress] = _evaluate_erasure_dependencies( + traversal_node, end_nodes + ) + for dep in erasure_dependencies: + networkx_graph.add_edge(dep, node_name) + + for node in end_nodes: + # Connect each end node without downstream dependencies to the terminator node + networkx_graph.add_edge(node, TERMINATOR_ADDRESS) + + try: + # Run extra checks on the graph since we potentially modified traversal_nodes + networkx.find_cycle(networkx_graph, ROOT_COLLECTION_ADDRESS) + except NetworkXNoCycle: + logger.info("No cycles found as expected") + else: + raise TraversalError( + "The values for the `erase_after` fields created a cycle in the DAG." + ) + + _add_edge_if_no_nodes(traversal_nodes, networkx_graph) + return networkx_graph + + +def build_consent_networkx_digraph( + traversal_nodes: Dict[CollectionAddress, TraversalNode], +) -> networkx.DiGraph: + """ + DSR 3.0: Builds a networkx graph of consent nodes to get consistent formatting of nodes to build the Request Tasks, + regardless of whether node is real or artificial. + """ + networkx_graph = networkx.DiGraph() + networkx_graph.add_nodes_from(traversal_nodes.keys()) + networkx_graph.add_nodes_from([TERMINATOR_ADDRESS, ROOT_COLLECTION_ADDRESS]) + + for collection_address, _ in traversal_nodes.items(): + # Consent graphs are simple. One node for every dataset (which has a mocked collection) + # and no dependencies between nodes. + networkx_graph.add_edge(ROOT_COLLECTION_ADDRESS, collection_address) + networkx_graph.add_edge(collection_address, TERMINATOR_ADDRESS) + + _add_edge_if_no_nodes(traversal_nodes, networkx_graph) + return networkx_graph + + +def base_task_data( + graph: networkx.DiGraph, + dataset_graph: DatasetGraph, + privacy_request: PrivacyRequest, + node: CollectionAddress, + traversal_nodes: Dict[CollectionAddress, TraversalNode], +) -> Dict: + """Build a dictionary of common RequestTask attributes that are shared for building + access, consent, and erasure tasks""" + collection_representation: Optional[Dict] = None + traversal_details = {} + + if node not in ARTIFICIAL_NODES: + # Save a representation of the collection that can be re-hydrated later + # when executing the node, so we don't have to recalculate incoming + # and outgoing edges. 
+ collection_representation = json.loads( + dataset_graph.nodes[node].collection.json() + ) + # Saves traversal details based on data dependencies like incoming edges + # and input keys, also useful for building the Execution Node + traversal_details = traversal_nodes[node].format_traversal_details_for_save() + + return { + "privacy_request_id": privacy_request.id, + "upstream_tasks": sorted( + [upstream.value for upstream in graph.predecessors(node)] + ), + "downstream_tasks": sorted( + [downstream.value for downstream in graph.successors(node)] + ), + "all_descendant_tasks": sorted( + [descend.value for descend in list(networkx.descendants(graph, node))] + ), + "collection_address": node.value, + "dataset_name": node.dataset, + "collection_name": node.collection, + "status": ExecutionLogStatus.complete + if node == ROOT_COLLECTION_ADDRESS + else ExecutionLogStatus.pending, + "collection": collection_representation, + "traversal_details": traversal_details, + } + + +def persist_new_access_request_tasks( + session: Session, + privacy_request: PrivacyRequest, + traversal: Traversal, + traversal_nodes: Dict[CollectionAddress, TraversalNode], + end_nodes: List[CollectionAddress], + dataset_graph: DatasetGraph, +) -> List[RequestTask]: + """ + Create individual access RequestTasks from the TraversalNodes and persist to the database. + This should only run the first time a privacy request runs. + """ + logger.info( + "Creating access request tasks for privacy request {}.", privacy_request.id + ) + graph: networkx.DiGraph = build_access_networkx_digraph( + traversal_nodes, end_nodes, traversal + ) + + for node in list(networkx.topological_sort(graph)): + if privacy_request.get_existing_request_task( + session, action_type=ActionType.access, collection_address=node + ): + continue + + RequestTask.create( + session, + data={ + **base_task_data( + graph, dataset_graph, privacy_request, node, traversal_nodes + ), + "access_data": json.dumps([traversal.seed_data], cls=CustomJSONEncoder) + if node == ROOT_COLLECTION_ADDRESS + else [], # For consistent treatment of nodes, add the seed data to the root node. Subsequent + # tasks will save the data collected on the same field. + "action_type": ActionType.access, + }, + ) + + root_task: RequestTask = privacy_request.get_root_task_by_action(ActionType.access) + + return [root_task] + + +def persist_initial_erasure_request_tasks( + session: Session, + privacy_request: PrivacyRequest, + traversal_nodes: Dict[CollectionAddress, TraversalNode], + end_nodes: List[CollectionAddress], + dataset_graph: DatasetGraph, +) -> List[RequestTask]: + """ + Create starter individual erasure RequestTasks from the TraversalNodes and persist to the database. 
+ + These are not ready to run yet as they are still waiting for access data from the access graph + to be able to build masking requests + """ + logger.info( + "Creating initial erasure request tasks for privacy request {}.", + privacy_request.id, + ) + graph: networkx.DiGraph = build_erasure_networkx_digraph(traversal_nodes, end_nodes) + + for node in list(networkx.topological_sort(graph)): + if privacy_request.get_existing_request_task( + session, action_type=ActionType.erasure, collection_address=node + ): + continue + + RequestTask.create( + session, + data={ + **base_task_data( + graph, dataset_graph, privacy_request, node, traversal_nodes + ), + "action_type": ActionType.erasure, + }, + ) + + # If a policy has an erasure rule, this method is run immediately after creating the access tasks, so their + # nodes in the database are the same. There are no "ready" tasks yet, because we need to wait for the + # access step to run, so we return an empty list here. + return [] + + +def _get_data_for_erasures( + session: Session, privacy_request: PrivacyRequest, request_task: RequestTask +) -> List[Dict]: + """ + Return the access data in erasure format needed to format the masking request for the current node. + """ + # Get the access task of the same name as the erasure task so we can transfer the data + # collected for masking onto the current erasure task + corresponding_access_task: Optional[ + RequestTask + ] = privacy_request.get_existing_request_task( + db=session, + action_type=ActionType.access, + collection_address=request_task.request_task_address, + ) + retrieved_task_data: List[Dict] = [] + if ( + corresponding_access_task + and request_task.request_task_address not in ARTIFICIAL_NODES + ): + # IMPORTANT. Use "data_for_erasures" - not RequestTask.access_data. + # For arrays, "access_data" may remove non-matched elements from arrays, but to build erasure + # queries we need the original data in the appropriate indices + retrieved_task_data = corresponding_access_task.get_decoded_data_for_erasures() + + return retrieved_task_data + + +def update_erasure_tasks_with_access_data( + session: Session, + privacy_request: PrivacyRequest, +) -> None: + """ + Update individual erasure RequestTasks with data from the TraversalNodes and persist to the database. + """ + logger.info( + "Updating erasure request tasks with data needed for masking requests {}.", + privacy_request.id, + ) + + for request_task in privacy_request.erasure_tasks: + # Pull access data saved in the format suitable for erasures + # off of the access nodes to be saved onto the erasure nodes. + retrieved_task_data = _get_data_for_erasures( + session, privacy_request, request_task + ) + request_task.data_for_erasures = json.dumps( + retrieved_task_data, cls=CustomJSONEncoder + ) + request_task.save(session) + + +def persist_new_consent_request_tasks( + session: Session, + privacy_request: PrivacyRequest, + traversal_nodes: Dict[CollectionAddress, TraversalNode], + identity: Dict[str, Any], + dataset_graph: DatasetGraph, +) -> List[RequestTask]: + """ + Create individual consent RequestTasks from the TraversalNodes and persist to the database. This should only + run the first time a privacy request runs. + + Consent propagation graphs are much simpler with no relationships between nodes. Every node has identity data input, + and every node outputs whether the consent request succeeded.
+ """ + graph: networkx.DiGraph = build_consent_networkx_digraph(traversal_nodes) + + for node in list(networkx.topological_sort(graph)): + if privacy_request.get_existing_request_task( + session, action_type=ActionType.consent, collection_address=node + ): + continue + RequestTask.create( + session, + data={ + **base_task_data( + graph, dataset_graph, privacy_request, node, traversal_nodes + ), + # Consent nodes take in identity data from their upstream root node + "access_data": json.dumps([identity], cls=CustomJSONEncoder) + if node == ROOT_COLLECTION_ADDRESS + else [], + "action_type": ActionType.consent, + }, + ) + + root_task: RequestTask = privacy_request.get_root_task_by_action(ActionType.consent) + + return [root_task] + + +def collect_tasks_fn( + tn: TraversalNode, data: Dict[CollectionAddress, TraversalNode] +) -> None: + """ + A function that is passed to traversal.traverse() that returns the modified + traversal node with its parents and children linked as an action. + """ + if not tn.is_root_node(): + data[tn.address] = tn + + +def run_access_request( + privacy_request: PrivacyRequest, + policy: Policy, + graph: DatasetGraph, + connection_configs: List[ConnectionConfig], + identity: Dict[str, Any], + session: Session, + privacy_request_proceed: bool = True, +) -> List[RequestTask]: + """ + DSR 3.0: Build the "access" graph, add its tasks to the database and queue the root task. If erasure rules + are present, build the "erasure" graph at the same time so their nodes match, but these erasure nodes are + not yet ready to run until the access graph is complete in-full. + + If we are *reprocessing* a Privacy Request, instead queue tasks whose upstream nodes are complete. + """ + + if privacy_request.access_tasks.count(): + # If we are reprocessing a privacy request, just see if there + # are existing ready tasks; don't create new ones. + # Possible edge cases here where we have no ready tasks and + # Privacy Request is hanging in an in-processing state. + ready_tasks: List[RequestTask] = get_existing_ready_tasks( + session, privacy_request, ActionType.access + ) + else: + logger.info("Building access graph for {}", privacy_request.id) + traversal: Traversal = Traversal(graph, identity) + + # Traversal.traverse populates traversal_nodes in place, adding parents and children to each traversal_node. + traversal_nodes: Dict[CollectionAddress, TraversalNode] = {} + end_nodes: List[CollectionAddress] = traversal.traverse( + traversal_nodes, collect_tasks_fn + ) + # Save Access Request Tasks to the database + ready_tasks = persist_new_access_request_tasks( + session, privacy_request, traversal, traversal_nodes, end_nodes, graph + ) + + if ( + policy.get_rules_for_action(action_type=ActionType.erasure) + and not privacy_request.erasure_tasks.count() + ): + # If applicable, go ahead and save Erasure Request Tasks to the Database. + # These erasure tasks aren't ready to run until the access graph is completed + # in full, but this makes sure the nodes in the graphs match. 
+ erasure_end_nodes: List[CollectionAddress] = list(graph.nodes.keys()) + persist_initial_erasure_request_tasks( + session, privacy_request, traversal_nodes, erasure_end_nodes, graph + ) + + # cache a map of collections -> data uses for the output package of access requests + privacy_request.cache_data_use_map( + format_data_use_map_for_caching( + { + coll_address: tn.node.dataset.connection_key + for (coll_address, tn) in traversal_nodes.items() + }, + connection_configs, + ) + ) + + for task in ready_tasks: + log_task_queued(task, "main runner") + queue_request_task(task, privacy_request_proceed) + + return ready_tasks + + +def run_erasure_request( # pylint: disable = too-many-arguments + privacy_request: PrivacyRequest, + session: Session, + privacy_request_proceed: bool = True, +) -> List[RequestTask]: + """ + DSR 3.0: Update erasure Request Tasks that were built in the "run_access_request" step with data + collected to build masking requests and queue the root task for processing. + + If we are reprocessing a Privacy Request, instead queue tasks whose upstream nodes are complete. + """ + update_erasure_tasks_with_access_data(session, privacy_request) + ready_tasks: List[RequestTask] = ( + get_existing_ready_tasks(session, privacy_request, ActionType.erasure) or [] + ) + + for task in ready_tasks: + log_task_queued(task, "main runner") + queue_request_task(task, privacy_request_proceed) + return ready_tasks + + +def run_consent_request( # pylint: disable = too-many-arguments + privacy_request: PrivacyRequest, + graph: DatasetGraph, + identity: Dict[str, Any], + session: Session, + privacy_request_proceed: bool = True, +) -> List[RequestTask]: + """ + DSR 3.0: Build the "consent" graph, add its tasks to the database and queue the root task. + + If we are reprocessing a Privacy Request, instead queue tasks whose upstream nodes are complete. + + The graph built is very simple: there are no relationships between the nodes, every node has + identity data input and every node outputs whether the consent request succeeded. + + The DatasetGraph passed in is expected to have one Node per Dataset. That Node is expected to carry out requests + for the Dataset as a whole. 
+ """ + + if privacy_request.consent_tasks.count(): + ready_tasks: List[RequestTask] = get_existing_ready_tasks( + session, privacy_request, ActionType.consent + ) + else: + logger.info("Building consent graph for {}", privacy_request.id) + traversal_nodes: Dict[CollectionAddress, TraversalNode] = {} + # Unlike erasure and access graphs, we don't call traversal.traverse, but build a simpler + # graph that just has one node per dataset + for col_address, node in graph.nodes.items(): + traversal_node = TraversalNode(node) + traversal_nodes[col_address] = traversal_node + + ready_tasks = persist_new_consent_request_tasks( + session, privacy_request, traversal_nodes, identity, graph + ) + + for task in ready_tasks: + log_task_queued(task, "main runner") + queue_request_task(task, privacy_request_proceed) + return ready_tasks + + +def get_existing_ready_tasks( + session: Session, privacy_request: PrivacyRequest, action_type: ActionType +) -> List[RequestTask]: + """ + Return existing RequestTasks if applicable in the event of reprocessing instead + of creating new ones + """ + ready: List[RequestTask] = [] + request_tasks: Query = privacy_request.get_tasks_by_action(action_type) + if request_tasks.count(): + incomplete_tasks: Query = request_tasks.filter( + RequestTask.status.notin_(COMPLETED_EXECUTION_LOG_STATUSES) + ) + + for task in incomplete_tasks: + # Checks if both upstream tasks are complete and the task is not currently in-flight (if using workers) + if task.can_queue_request_task(session, should_log=True): + task.update_status(session, ExecutionLogStatus.pending) + ready.append(task) + elif task.status == ExecutionLogStatus.error: + # Important to reset errored status to pending so it can be rerun + task.update_status(session, ExecutionLogStatus.pending) + + if ready: + logger.info( + "Found existing {} task(s) ready to reprocess: {}. Privacy Request: {}", + action_type.value, + [t.collection_address for t in ready], + privacy_request.id, + ) + return ready + return ready diff --git a/src/fides/api/task/deprecated_graph_task.py b/src/fides/api/task/deprecated_graph_task.py new file mode 100644 index 0000000000..f46f549d3e --- /dev/null +++ b/src/fides/api/task/deprecated_graph_task.py @@ -0,0 +1,313 @@ +# pylint: disable=too-many-lines +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import dask +from dask import delayed # type: ignore[attr-defined] +from dask.core import getcycle +from dask.threaded import get +from sqlalchemy.orm import Session + +from fides.api.common_exceptions import TraversalError +from fides.api.graph.config import ( + ROOT_COLLECTION_ADDRESS, + TERMINATOR_ADDRESS, + CollectionAddress, +) +from fides.api.graph.graph import DatasetGraph +from fides.api.graph.traversal import Traversal, TraversalNode +from fides.api.models.connectionconfig import ConnectionConfig +from fides.api.models.policy import Policy +from fides.api.models.privacy_request import PrivacyRequest +from fides.api.models.sql_models import System # type: ignore[attr-defined] +from fides.api.task.graph_task import EMPTY_REQUEST_TASK, GraphTask +from fides.api.task.task_resources import TaskResources +from fides.api.util.collection_util import Row + +# These are deprecated DSR 2.0 functions that support running DSR's in sequence with Dask in-memory +# Supported for a limited time. 
+ + dask.config.set(scheduler="threads") + + +def update_mapping_from_cache( + dsk: Dict[CollectionAddress, Tuple[Any, ...]], + resources: TaskResources, + start_fn: Callable, +) -> None: + """When resuming a privacy request from a paused or failed state, update the `dsk` dictionary with results we've + already obtained from a previous run. Remove upstream dependencies for these nodes, and just return the data we've + already retrieved, rather than visiting them again. + + If there's no cached data, the dsk dictionary won't change. + """ + + cached_results: Dict[str, Optional[List[Row]]] = resources.get_all_cached_objects() + + for collection_name in cached_results: + dsk[CollectionAddress.from_string(collection_name)] = ( + start_fn(cached_results[collection_name]), + ) + + +def format_data_use_map_for_caching( + connection_key_mapping: Dict[CollectionAddress, str], + connection_configs: List[ConnectionConfig], +) -> Dict[str, Set[str]]: + """ + Create a map of `Collection`s mapped to their associated `DataUse`s + to be stored in the cache. This is done before request execution, so that we + maintain the _original_ state of the graph as it's used for request execution. + The graph is subject to change "from underneath" the request execution runtime, + but we want to avoid picking up those changes in our data use map. + + `DataUse`s are associated with a `Collection` by means of the `System` + that's linked to a `Collection`'s `Connection` definition. + + Example: + { + <collection_address_1>: {"data_use_1", "data_use_2"}, + <collection_address_2>: {"data_use_1"}, + } + """ + resp: Dict[str, Set[str]] = {} + connection_config_mapping: Dict[str, ConnectionConfig] = { + connection_config.key: connection_config + for connection_config in connection_configs + } + for collection_addr, connection_key in connection_key_mapping.items(): + connection_config = connection_config_mapping.get(connection_key, None) + if not connection_config or not connection_config.system: + resp[collection_addr.value] = set() + continue + data_uses: Set[str] = System.get_data_uses( + [connection_config.system], include_parents=False + ) + resp[collection_addr.value] = data_uses + + return resp + + +def start_function(seed: List[Dict[str, Any]]) -> Callable[[], List[Dict[str, Any]]]: + """Return a function for collections with no upstream dependencies, that just start + with seed data. + + This is used for root nodes or previously-visited nodes on restart.""" + + def g() -> List[Dict[str, Any]]: + return seed + + return g + + +def run_access_request_deprecated( + privacy_request: PrivacyRequest, + policy: Policy, + graph: DatasetGraph, + connection_configs: List[ConnectionConfig], + identity: Dict[str, Any], + session: Session, +) -> Dict[str, List[Row]]: + """Deprecated: Run the access request sequentially in-memory using Dask""" + traversal: Traversal = Traversal(graph, identity) + with TaskResources( + privacy_request, policy, connection_configs, EMPTY_REQUEST_TASK, session + ) as resources: + + def collect_tasks_fn( + tn: TraversalNode, data: Dict[CollectionAddress, GraphTask] + ) -> None: + """Run the traversal, as an action creating a GraphTask for each traversal_node.""" + if not tn.is_root_node(): + # Mock a RequestTask object in memory to share code with DSR 3.0 + resources.privacy_request_task = tn.to_mock_request_task() + data[tn.address] = GraphTask(resources) + + def termination_fn( + *dependent_values: List[Row], + ) -> Dict[str, Optional[List[Row]]]: + """A termination function that just returns its inputs mapped to their source addresses.
+ This needs to wait for all dependent keys because this is how dask is informed to wait for + all terminating addresses before calling this.""" + + return resources.get_all_cached_objects() + + env: Dict[CollectionAddress, GraphTask] = {} + end_nodes: List[CollectionAddress] = traversal.traverse(env, collect_tasks_fn) + + dsk: Dict[CollectionAddress, Tuple[Any, ...]] = { + k: (t.access_request, *t.execution_node.input_keys) for k, t in env.items() + } + dsk[ROOT_COLLECTION_ADDRESS] = (start_function([traversal.seed_data]),) + dsk[TERMINATOR_ADDRESS] = (termination_fn, *end_nodes) + update_mapping_from_cache(dsk, resources, start_function) + + # cache a map of collections -> data uses for the output package of access requests + # this is cached here before request execution, since this is the state of the + # graph used for request execution. the graph could change _during_ request execution, + # but we don't want those changes in our data use map. + privacy_request.cache_data_use_map( + format_data_use_map_for_caching( + { + coll_address: gt.execution_node.connection_key + for (coll_address, gt) in env.items() + }, + connection_configs, + ) + ) + + v = delayed(get(dsk, TERMINATOR_ADDRESS, num_workers=1)) + return v.compute() + + +def update_erasure_mapping_from_cache( + dsk: Dict[CollectionAddress, Union[Tuple[Any, ...], int]], resources: TaskResources +) -> None: + """On pause or restart from failure, update the dsk graph to skip running erasures on collections + we've already visited. Instead, just return the previous count of rows affected. + + If there's no cached data, the dsk dictionary won't change. + """ + cached_erasures: Dict[str, int] = resources.get_all_cached_erasures() + + for collection_name in cached_erasures: + dsk[CollectionAddress.from_string(collection_name)] = cached_erasures[ + collection_name + ] + + +def run_erasure_request_deprecated( # pylint: disable = too-many-arguments + privacy_request: PrivacyRequest, + policy: Policy, + graph: DatasetGraph, + connection_configs: List[ConnectionConfig], + identity: Dict[str, Any], + access_request_data: Dict[str, List[Row]], + session: Session, +) -> Dict[str, int]: + """Deprecated: Run an erasure request sequentially in-memory using Dask""" + traversal: Traversal = Traversal(graph, identity) + with TaskResources( + privacy_request, policy, connection_configs, EMPTY_REQUEST_TASK, session + ) as resources: + + def collect_tasks_fn( + tn: TraversalNode, data: Dict[CollectionAddress, GraphTask] + ) -> None: + """Run the traversal, as an action creating a GraphTask for each traversal_node.""" + if not tn.is_root_node(): + # Mock a RequestTask object in memory to share code with DSR 3.0 + resources.privacy_request_task = tn.to_mock_request_task() + data[tn.address] = GraphTask(resources) + + env: Dict[CollectionAddress, GraphTask] = {} + # Modifies env in place + traversal.traverse(env, collect_tasks_fn) + erasure_end_nodes = list(graph.nodes.keys()) + + def termination_fn(*dependent_values: int) -> Dict[str, int]: + """ + The erasure order can be affected in a way that not every node is directly linked + to the termination node. This means that we can't just aggregate the inputs directly, + we must read the erasure results from the cache. 
+ """ + return resources.get_all_cached_erasures() + + access_request_data[ROOT_COLLECTION_ADDRESS.value] = [identity] + + dsk: Dict[CollectionAddress, Any] = { + k: ( + t.erasure_request, + access_request_data.get( + str(k), [] + ), # Pass in the results of the access request for this collection + *_evaluate_erasure_dependencies(t, erasure_end_nodes), + ) + for k, t in env.items() + } + + # root node returns 0 to be consistent with the output of the other erasure tasks + dsk[ROOT_COLLECTION_ADDRESS] = 0 + # terminator function reads and returns the cached erasure results for the entire erasure traversal + dsk[TERMINATOR_ADDRESS] = (termination_fn, *erasure_end_nodes) + update_erasure_mapping_from_cache(dsk, resources) + + # using an existing function from dask.core to detect cycles in the generated graph + collection_cycle = getcycle(dsk, None) + if collection_cycle: + raise TraversalError( + f"The values for the `erase_after` fields caused a cycle in the following collections {collection_cycle}" + ) + + v = delayed(get(dsk, TERMINATOR_ADDRESS, num_workers=1)) + return v.compute() + + +def _evaluate_erasure_dependencies( + t: GraphTask, end_nodes: List[CollectionAddress] +) -> Set[CollectionAddress]: + """ + Return a set of collection addresses corresponding to collections that need + to be erased before the given task. Remove the dependent collection addresses + from `end_nodes` so they can be executed in the correct order. If a task does + not have any dependencies it is linked directly to the root node + """ + erase_after = t.execution_node.collection.erase_after + for collection in erase_after: + if collection in end_nodes: + # end_node list is modified in place + end_nodes.remove(collection) + # this task will execute after the collections in `erase_after` or + # execute at the beginning by linking it to the root node + return erase_after if len(erase_after) else {ROOT_COLLECTION_ADDRESS} + + +def run_consent_request_deprecated( # pylint: disable = too-many-arguments + privacy_request: PrivacyRequest, + policy: Policy, + graph: DatasetGraph, + connection_configs: List[ConnectionConfig], + identity: Dict[str, Any], + session: Session, +) -> Dict[str, bool]: + """Run a consent request + + The graph built is very simple: there are no relationships between the nodes, every node has + identity data input and every node outputs whether the consent request succeeded. + + The DatasetGraph passed in is expected to have one Node per Dataset. That Node is expected to carry out requests + for the Dataset as a whole. 
+ """ + with TaskResources( + privacy_request, policy, connection_configs, EMPTY_REQUEST_TASK, session + ) as resources: + graph_keys: List[CollectionAddress] = list(graph.nodes.keys()) + dsk: Dict[CollectionAddress, Any] = {} + + for col_address, node in graph.nodes.items(): + traversal_node = TraversalNode(node) + # Mock a RequestTask object in memory to share code with DSR 3.0 + resources.privacy_request_task = traversal_node.to_mock_request_task() + task = GraphTask(resources) + dsk[col_address] = (task.consent_request, identity) + + def termination_fn(*dependent_values: bool) -> Tuple[bool, ...]: + """The dependent_values here is an bool output from each task feeding in, where + each task reports the output of 'task.consent_request(identity_data)', which is whether the + consent request succeeded + + The termination function just returns this tuple of booleans.""" + return dependent_values + + # terminator function waits for all keys + dsk[TERMINATOR_ADDRESS] = (termination_fn, *graph_keys) + + v = delayed(get(dsk, TERMINATOR_ADDRESS, num_workers=1)) + + update_successes: Tuple[bool, ...] = v.compute() + # we combine the output of the termination function with the input keys to provide + # a map of {collection_name: whether consent request succeeded}: + consent_update_map: Dict[str, bool] = dict( + zip([coll.value for coll in graph_keys], update_successes) + ) + + return consent_update_map diff --git a/src/fides/api/task/execute_request_tasks.py b/src/fides/api/task/execute_request_tasks.py new file mode 100644 index 0000000000..04481f9542 --- /dev/null +++ b/src/fides/api/task/execute_request_tasks.py @@ -0,0 +1,428 @@ +from typing import Callable, List, Optional, Tuple + +from celery.app.task import Task +from loguru import logger +from sqlalchemy.orm import Query, Session + +from fides.api.common_exceptions import ( + PrivacyRequestCanceled, + PrivacyRequestNotFound, + RequestTaskNotFound, + ResumeTaskException, + UpstreamTasksNotReady, +) +from fides.api.graph.config import TERMINATOR_ADDRESS, CollectionAddress +from fides.api.models.connectionconfig import ConnectionConfig +from fides.api.models.policy import CurrentStep +from fides.api.models.privacy_request import ( + ExecutionLog, + ExecutionLogStatus, + PrivacyRequest, + PrivacyRequestStatus, + RequestTask, +) +from fides.api.schemas.policy import ActionType +from fides.api.task.graph_task import ( + GraphTask, + mark_current_and_downstream_nodes_as_failed, +) +from fides.api.task.task_resources import TaskResources +from fides.api.tasks import DatabaseTask, celery_app +from fides.api.util.cache import cache_task_tracking_key +from fides.api.util.collection_util import Row + +# DSR 3.0 task functions + + +def run_prerequisite_task_checks( + session: Session, privacy_request_id: str, privacy_request_task_id: str +) -> Tuple[PrivacyRequest, RequestTask, Query]: + """ + Upfront checks that run as soon as the RequestTask is executed by the worker. 
+ + Returns resources for use in executing a task + """ + privacy_request: Optional[PrivacyRequest] = PrivacyRequest.get( + db=session, object_id=privacy_request_id + ) + request_task: Optional[RequestTask] = RequestTask.get( + db=session, object_id=privacy_request_task_id + ) + + if not privacy_request: + raise PrivacyRequestNotFound( + f"Privacy request with id {privacy_request_id} not found" + ) + + if privacy_request.status == PrivacyRequestStatus.canceled: + raise PrivacyRequestCanceled( + f"Cannot execute request task {privacy_request_task_id} of privacy request {privacy_request_id}: status is {privacy_request.status.value}" + ) + + if not request_task or not request_task.privacy_request_id == privacy_request.id: + raise RequestTaskNotFound( + f"Request Task with id {privacy_request_task_id} not found for privacy request {privacy_request_id}" + ) + + assert request_task # For mypy + + upstream_results: Query = request_task.upstream_tasks_objects(session) + + # Only bother running this if the current task body needs to run + if request_task.status == ExecutionLogStatus.pending: + # Only running the upstream check instead of RequestTask.can_queue_request_task since + # the node is already queued. + if not request_task.upstream_tasks_complete(session, should_log=False): + raise UpstreamTasksNotReady( + f"Cannot start {request_task.action_type} task {request_task.collection_address}. Privacy Request: {privacy_request.id}, Request Task {request_task.id}. Waiting for upstream tasks to finish." + ) + + return privacy_request, request_task, upstream_results + + +def create_graph_task( + session: Session, request_task: RequestTask, resources: TaskResources +) -> GraphTask: + """Hydrates a GraphTask from the saved collection details on the Request Task in the database + + This could fail if things like our Collection definitions have changed since we created the Task + to begin with - this may be unrecoverable and a new Privacy Request should be created. + """ + try: + graph_task: GraphTask = GraphTask(resources) + + except Exception as exc: + logger.debug( + "Cannot execute task - error loading task from database. Privacy Request: {}, Request Task {}. Exception {}", + request_task.privacy_request_id, + request_task.id, + str(exc), + ) + # Normally the GraphTask takes care of creating the ExecutionLog, but in this case we can't create it in the first place! + ExecutionLog.create( + db=session, + data={ + "connection_key": None, + "dataset_name": request_task.dataset_name, + "collection_name": request_task.collection_name, + "fields_affected": [], + "action_type": request_task.action_type, + "status": ExecutionLogStatus.error, + "privacy_request_id": request_task.privacy_request_id, + "message": str(exc), + }, + ) + mark_current_and_downstream_nodes_as_failed(request_task, session) + + raise ResumeTaskException( + f"Cannot resume request task. Error hydrating task from database: Request Task {request_task.id} for Privacy Request {request_task.privacy_request_id}. {exc}" + ) + + return graph_task + + +def can_run_task_body( + request_task: RequestTask, +) -> bool: + """Return True if we can execute the task body. We should skip if the task is already + complete or this is a root/terminator node""" + if request_task.is_terminator_task: + logger.info( + "Terminator {} task reached. 
Privacy Request: {}, Request Task {}", + request_task.action_type.value, + request_task.privacy_request_id, + request_task.id, + ) + return False + if request_task.is_root_task: + # Shouldn't be possible but adding as a catch-all + return False + if request_task.status != ExecutionLogStatus.pending: + logger_method(request_task)( + "Skipping {} task {} with status {}. Privacy Request: {}, Request Task {}", + request_task.action_type.value, + request_task.collection_address, + request_task.status.value, + request_task.privacy_request_id, + request_task.id, + ) + return False + + return True + + +def queue_downstream_tasks( + session: Session, + request_task: RequestTask, + privacy_request: PrivacyRequest, + next_step: CurrentStep, + privacy_request_proceed: bool, +) -> None: + """Queue downstream tasks of the current node **if** the downstream task has all its upstream tasks completed. + + If we've reached the terminator task, restart the privacy request from the appropriate checkpoint. + """ + pending_downstream: Query = request_task.get_pending_downstream_tasks(session) + for downstream_task in pending_downstream: + if downstream_task.can_queue_request_task(session, should_log=True): + log_task_queued(downstream_task, request_task.collection_address) + queue_request_task(downstream_task, privacy_request_proceed) + + if ( + request_task.request_task_address == TERMINATOR_ADDRESS + and request_task.status != ExecutionLogStatus.complete + ): + # Only queue privacy request from the next step if we haven't reached the terminator before. + # Multiple pathways could mark the same node as complete, so we may have already reached the + # terminator node through a quicker path. + from fides.api.service.privacy_request.request_runner_service import ( + queue_privacy_request, + ) + + if ( + privacy_request_proceed + ): # For Testing, this could be set to False, so we could just + # run one of the graphs and not the entire privacy request + queue_privacy_request( + privacy_request_id=privacy_request.id, + from_step=next_step.value, + ) + request_task.update_status(session, ExecutionLogStatus.complete) + + +@celery_app.task(base=DatabaseTask, bind=True) +def run_access_node( + self: DatabaseTask, + privacy_request_id: str, + privacy_request_task_id: str, + privacy_request_proceed: bool = True, +) -> None: + """Run an individual task in the access graph for DSR 3.0 and queue downstream nodes + upon completion if applicable""" + with self.get_new_session() as session: + privacy_request, request_task, upstream_results = run_prerequisite_task_checks( + session, privacy_request_id, privacy_request_task_id + ) + log_task_starting(request_task) + + if can_run_task_body(request_task): + # Build GraphTask resource to facilitate execution + with TaskResources( + privacy_request, + privacy_request.policy, + session.query(ConnectionConfig).all(), + request_task, + session, + ) as resources: + graph_task: GraphTask = create_graph_task( + session, request_task, resources + ) + # Currently, upstream tasks and "input keys" (which are built by data dependencies) + # are the same, but they may not be the same in the future. + ordered_upstream_tasks: List[ + Optional[RequestTask] + ] = _order_tasks_by_input_key( + graph_task.execution_node.input_keys, upstream_results + ) + # Pass in access data dependencies in the same order as the input keys. 
+ # If we don't have access data for an upstream node, pass in an empty list + upstream_access_data: List[List[Row]] = [ + upstream.get_decoded_access_data() if upstream else [] + for upstream in ordered_upstream_tasks + ] + # Run the main access function + graph_task.access_request(*upstream_access_data) + log_task_complete(request_task) + + queue_downstream_tasks( + session, + request_task, + privacy_request, + CurrentStep.upload_access, + privacy_request_proceed, + ) + return + + +@celery_app.task(base=DatabaseTask, bind=True) +def run_erasure_node( + self: DatabaseTask, + privacy_request_id: str, + privacy_request_task_id: str, + privacy_request_proceed: bool = True, +) -> None: + """Run an individual task in the erasure graph for DSR 3.0 and queue downstream nodes + upon completion if applicable""" + with self.get_new_session() as session: + privacy_request, request_task, _ = run_prerequisite_task_checks( + session, privacy_request_id, privacy_request_task_id + ) + log_task_starting(request_task) + + if can_run_task_body(request_task): + with TaskResources( + privacy_request, + privacy_request.policy, + session.query(ConnectionConfig).all(), + request_task, + session, + ) as resources: + # Build GraphTask resource to facilitate execution + graph_task: GraphTask = create_graph_task( + session, request_task, resources + ) + # Get access data that was saved in the erasure format that was collected from the + # access task for the same collection. This data is used to build the masking request + retrieved_data: List[Row] = ( + request_task.get_decoded_data_for_erasures() or [] + ) + + # Run the main erasure function! + graph_task.erasure_request(retrieved_data) + + log_task_complete(request_task) + + queue_downstream_tasks( + session, + request_task, + privacy_request, + CurrentStep.finalize_erasure, + privacy_request_proceed, + ) + return + + +@celery_app.task(base=DatabaseTask, bind=True) +def run_consent_node( + self: DatabaseTask, + privacy_request_id: str, + privacy_request_task_id: str, + privacy_request_proceed: bool = True, +) -> None: + """Run an individual task in the consent graph for DSR 3.0 and queue downstream nodes + upon completion if applicable""" + with self.get_new_session() as session: + privacy_request, request_task, upstream_results = run_prerequisite_task_checks( + session, privacy_request_id, privacy_request_task_id + ) + log_task_starting(request_task) + + if can_run_task_body(request_task): + # Build GraphTask resource to facilitate execution + with TaskResources( + privacy_request, + privacy_request.policy, + session.query(ConnectionConfig).all(), + request_task, + session, + ) as resources: + graph_task: GraphTask = create_graph_task( + session, request_task, resources + ) + if upstream_results: + # For consent, expected that there is only one upstream node, the root node, + # and it holds the identity data (stored in a list for consistency with other + # data stored in access_data) + access_data: List = ( + upstream_results[0].get_decoded_access_data() or [] + ) + + graph_task.consent_request(access_data[0] if access_data else {}) + + log_task_complete(request_task) + + queue_downstream_tasks( + session, + request_task, + privacy_request, + CurrentStep.finalize_consent, + privacy_request_proceed, + ) + return + + +def logger_method(request_task: RequestTask) -> Callable: + """Log selected no-op items with debug method and others with info method""" + return ( + logger.debug + if request_task.status == ExecutionLogStatus.complete + else logger.info + ) + + +def 
log_task_starting(request_task: RequestTask) -> None:
+    """Convenience method for logging task start"""
+    logger_method(request_task)(
+        "Starting '{}' task {} with current status '{}'. Privacy Request: {}, Request Task {}",
+        request_task.action_type,
+        request_task.collection_address,
+        request_task.status.value,
+        request_task.privacy_request_id,
+        request_task.id,
+    )
+
+
+def log_task_complete(request_task: RequestTask) -> None:
+    """Convenience method for logging task completion"""
+    logger.info(
+        "{} task {} is {}. Privacy Request: {}, Request Task {}",
+        request_task.action_type.value.capitalize(),
+        request_task.collection_address,
+        request_task.status.value,
+        request_task.privacy_request_id,
+        request_task.id,
+    )
+
+
+def _order_tasks_by_input_key(
+    input_keys: List[CollectionAddress], upstream_tasks: Query
+) -> List[Optional[RequestTask]]:
+    """Order tasks by input key. If a task doesn't exist, add None in its place.
+
+    Data being passed to GraphTask.access_request is expected to have the same order
+    as input keys so we know which data belongs to which upstream collection
+    """
+    tasks: List[Optional[RequestTask]] = []
+    for key in input_keys:
+        task = next(
+            (
+                upstream
+                for upstream in upstream_tasks
+                if upstream.collection_address == key.value
+            ),
+            None,
+        )
+        tasks.append(task)
+    return tasks
+
+
+mapping = {
+    ActionType.access: run_access_node,
+    ActionType.erasure: run_erasure_node,
+    ActionType.consent: run_consent_node,
+}
+
+
+def queue_request_task(
+    request_task: RequestTask, privacy_request_proceed: bool = True
+) -> None:
+    """Queues the RequestTask in Celery and caches the Celery Task ID"""
+    celery_task_fn: Task = mapping[request_task.action_type]
+    celery_task = celery_task_fn.delay(
+        privacy_request_id=request_task.privacy_request_id,
+        privacy_request_task_id=request_task.id,
+        privacy_request_proceed=privacy_request_proceed,
+    )
+    cache_task_tracking_key(request_task.id, celery_task.task_id)
+
+
+def log_task_queued(request_task: RequestTask, location: str) -> None:
+    """Helper for logging that tasks are queued"""
+    logger_method(request_task)(
+        "Queuing {} task {} from {}. Privacy Request: {}, Request Task {}",
+        request_task.action_type.value,
+        request_task.collection_address,
+        location,
+        request_task.privacy_request_id,
+        request_task.id,
+    )
diff --git a/src/fides/api/task/graph_runners.py b/src/fides/api/task/graph_runners.py
new file mode 100644
index 0000000000..155910a269
--- /dev/null
+++ b/src/fides/api/task/graph_runners.py
@@ -0,0 +1,159 @@
+from typing import Any, Dict, List, Optional
+
+from loguru import logger
+from sqlalchemy.orm import Session
+
+from fides.api.common_exceptions import PrivacyRequestExit
+from fides.api.graph.graph import DatasetGraph
+from fides.api.models.connectionconfig import ConnectionConfig
+from fides.api.models.policy import Policy
+from fides.api.models.privacy_request import PrivacyRequest
+from fides.api.schemas.policy import ActionType
+from fides.api.task.create_request_tasks import (
+    run_access_request,
+    run_consent_request,
+    run_erasure_request,
+)
+from fides.api.task.deprecated_graph_task import (
+    run_access_request_deprecated,
+    run_consent_request_deprecated,
+    run_erasure_request_deprecated,
+)
+from fides.api.util.collection_util import Row
+from fides.config import CONFIG
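
Before the runner definitions, a hedged summary of the scheduler-selection logic implemented by `use_dsr_3_0_scheduler` below; the table restates the code and is illustrative, not part of the patch:

```python
# Illustrative decision table for use_dsr_3_0_scheduler (defined below):
#
#   use_dsr_3_0 flag | cached DSR 2.0 results | existing RequestTasks | scheduler
#   -----------------+------------------------+-----------------------+----------
#   False            | any                    | any                   | DSR 2.0
#   True             | none                   | any                   | DSR 3.0
#   True             | present                | some                  | DSR 3.0
#   True             | present                | none                  | DSR 2.0
#
# The last row is the override: a request partially processed on DSR 2.0 (raw
# access results cached, but no RequestTask rows yet) finishes on DSR 2.0
# rather than switching schedulers mid-request.
```

+
+
+def use_dsr_3_0_scheduler(
+    privacy_request: PrivacyRequest, action_type: ActionType
+) -> bool:
+    """Return whether we should use the DSR 3.0 scheduler.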
+
+    Fall back to DSR 2.0 if we have a partially processed Privacy Request that was
+    already run on DSR 2.0, so we can finish processing it there.
+
+    """
+    use_dsr_3_0 = CONFIG.execution.use_dsr_3_0
+
+    prev_results: Dict[
+        str, Optional[List[Row]]
+    ] = privacy_request.get_raw_access_results()
+    existing_tasks_count: int = privacy_request.get_tasks_by_action(action_type).count()
+
+    if prev_results and use_dsr_3_0 and not existing_tasks_count:
+        # If we've previously tried to process this Privacy Request using DSR 2.0, continue doing so
+        # for access and erasure requests
+        logger.info(
+            "Overriding scheduler to run privacy request {} using DSR 2.0 as it's "
+            "already partially processed",
+            privacy_request.id,
+        )
+        use_dsr_3_0 = False
+
+    return use_dsr_3_0
+
+
+def access_runner(
+    privacy_request: PrivacyRequest,
+    policy: Policy,
+    graph: DatasetGraph,
+    connection_configs: List[ConnectionConfig],
+    identity: Dict[str, Any],
+    session: Session,
+    privacy_request_proceed: bool = True,  # Can be set to False in testing to run this in isolation
+) -> Dict[str, List[Row]]:
+    """
+    Access runner that temporarily supports running Access Requests with either DSR 3.0 or DSR 2.0
+
+    DSR 2.0 will be going away
+    """
+    use_dsr_3_0 = use_dsr_3_0_scheduler(privacy_request, ActionType.access)
+
+    if use_dsr_3_0:
+        run_access_request(
+            privacy_request=privacy_request,
+            policy=policy,
+            graph=graph,
+            connection_configs=connection_configs,
+            identity=identity,
+            session=session,
+            privacy_request_proceed=privacy_request_proceed,
+        )
+        raise PrivacyRequestExit()
+
+    return run_access_request_deprecated(
+        privacy_request=privacy_request,
+        policy=policy,
+        graph=graph,
+        connection_configs=connection_configs,
+        identity=identity,
+        session=session,
+    )
+
+
+def erasure_runner(
+    privacy_request: PrivacyRequest,
+    policy: Policy,
+    graph: DatasetGraph,
+    connection_configs: List[ConnectionConfig],
+    identity: Dict[str, Any],
+    access_request_data: Dict[str, List[Row]],
+    session: Session,
+    privacy_request_proceed: bool = True,  # Can be set to False in testing to run this in isolation
+) -> Dict[str, int]:
+    """Erasure runner that temporarily supports running Erasure DAGs with DSR 3.0 or 2.0.
+
+    DSR 2.0 will be going away
+    """
+    use_dsr_3_0 = use_dsr_3_0_scheduler(privacy_request, ActionType.erasure)
+
+    if use_dsr_3_0:
+        run_erasure_request(
+            privacy_request=privacy_request,
+            session=session,
+            privacy_request_proceed=privacy_request_proceed,
+        )
+        raise PrivacyRequestExit()
+
+    return run_erasure_request_deprecated(
+        privacy_request=privacy_request,
+        policy=policy,
+        graph=graph,
+        connection_configs=connection_configs,
+        identity=identity,
+        access_request_data=access_request_data,
+        session=session,
+    )
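
Because the DSR 3.0 branches above raise `PrivacyRequestExit` rather than returning results inline, callers branch on the exception. A hedged sketch of the calling pattern (this mirrors how the test helpers later in this diff consume the runners; it is illustrative, not the request runner service's exact code):

```python
# Illustrative only: DSR 2.0 returns results synchronously; DSR 3.0 queues
# RequestTasks and exits, resuming the privacy request once the terminator
# node is reached.
from fides.api.common_exceptions import PrivacyRequestExit
from fides.api.task.graph_runners import access_runner

try:
    results = access_runner(
        privacy_request, policy, graph, connection_configs, identity, session
    )
    # DSR 2.0 path: access results are available immediately
except PrivacyRequestExit:
    # DSR 3.0 path: RequestTasks are now running in parallel; the request is
    # re-queued from CurrentStep.upload_access when the terminator task runs,
    # so there is nothing more to do here.
    results = None
```

+
+
+def consent_runner(
+    privacy_request: PrivacyRequest,
+    policy: Policy,
+    graph: DatasetGraph,
+    connection_configs: List[ConnectionConfig],
+    identity: Dict[str, Any],
+    session: Session,
+    privacy_request_proceed: bool = True,  # Can be set to False in testing to run this in isolation
+) -> Dict[str, bool]:
+    """Consent runner that temporarily supports running Consent DAGs with DSR 3.0 or 2.0.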
+ + DSR 2.0 will be going away""" + use_dsr_3_0 = use_dsr_3_0_scheduler(privacy_request, ActionType.consent) + + if use_dsr_3_0: + run_consent_request( + privacy_request=privacy_request, + graph=graph, + identity=identity, + session=session, + privacy_request_proceed=privacy_request_proceed, + ) + raise PrivacyRequestExit() + + return run_consent_request_deprecated( + privacy_request=privacy_request, + policy=policy, + graph=graph, + connection_configs=connection_configs, + identity=identity, + session=session, + ) diff --git a/src/fides/api/task/graph_task.py b/src/fides/api/task/graph_task.py index 09b36fdab4..6a71dbbe00 100644 --- a/src/fides/api/task/graph_task.py +++ b/src/fides/api/task/graph_task.py @@ -1,15 +1,12 @@ # pylint: disable=too-many-lines import copy +import json import traceback from abc import ABC from functools import wraps from time import sleep -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import dask -from dask import delayed # type: ignore[attr-defined] -from dask.core import getcycle -from dask.threaded import get from loguru import logger from ordered_set import OrderedSet from sqlalchemy.orm import Session @@ -21,34 +18,38 @@ PrivacyRequestErasureEmailSendRequired, PrivacyRequestPaused, SkippingConsentPropagation, - TraversalError, -) -from fides.api.graph.analytics_events import ( - fideslog_graph_rerun, - prepare_rerun_graph_analytics_event, ) from fides.api.graph.config import ( ROOT_COLLECTION_ADDRESS, - TERMINATOR_ADDRESS, CollectionAddress, Field, FieldAddress, FieldPath, + GraphDataset, ) -from fides.api.graph.graph import DatasetGraph, Edge, Node -from fides.api.graph.graph_differences import format_graph_for_caching +from fides.api.graph.execution import ExecutionNode +from fides.api.graph.graph import DatasetGraph from fides.api.graph.traversal import Traversal, TraversalNode -from fides.api.models.connectionconfig import AccessLevel, ConnectionConfig +from fides.api.models.connectionconfig import ( + AccessLevel, + ConnectionConfig, + ConnectionType, +) +from fides.api.models.datasetconfig import DatasetConfig from fides.api.models.policy import Policy -from fides.api.models.privacy_request import ExecutionLogStatus, PrivacyRequest -from fides.api.models.sql_models import System # type: ignore[attr-defined] +from fides.api.models.privacy_request import ( + ExecutionLog, + ExecutionLogStatus, + PrivacyRequest, + RequestTask, +) from fides.api.schemas.policy import ActionType from fides.api.service.connectors.base_connector import BaseConnector from fides.api.task.consolidate_query_matches import consolidate_query_matches from fides.api.task.filter_element_match import filter_element_match from fides.api.task.refine_target_path import FieldPathNodeInput from fides.api.task.task_resources import TaskResources -from fides.api.util.cache import get_cache +from fides.api.util.cache import CustomJSONEncoder, get_cache from fides.api.util.collection_util import ( NodeInput, Row, @@ -56,18 +57,16 @@ extract_key_for_address, make_immutable, make_mutable, - partition, ) from fides.api.util.consent_util import add_errored_system_status_for_consent_reporting from fides.api.util.logger import Pii from fides.api.util.saas_util import FIDESOPS_GROUPED_INPUTS from fides.config import CONFIG -dask.config.set(scheduler="threads") - COLLECTION_FIELD_PATH_MAP = Dict[CollectionAddress, List[Tuple[FieldPath, FieldPath]]] EMPTY_REQUEST = PrivacyRequest() +EMPTY_REQUEST_TASK = 
RequestTask() def retry( @@ -107,17 +106,18 @@ def result(*args: Any, **kwargs: Any) -> Any: logger.warning( "Privacy request {} paused {}", method_name, - self.traversal_node.address, + self.execution_node.address, ) self.log_paused(action_type, ex) # Re-raise to stop privacy request execution on pause. raise except PrivacyRequestErasureEmailSendRequired as exc: traceback.print_exc() + self.request_task.rows_masked = 0 self.log_end(action_type, ex=None, success_override_msg=exc) self.resources.cache_erasure( f"{self.traversal_node.address.value}", 0 - ) # Cache that the erasure was performed in case we need to restart + ) # Cache that the erasure was performed in case we need to restart for DSR 2.0 return 0 except ( CollectionDisabled, @@ -127,7 +127,7 @@ def result(*args: Any, **kwargs: Any) -> Any: traceback.print_exc() logger.warning( "Skipping collection {} for privacy_request: {}", - self.traversal_node.address, + self.execution_node.address, self.resources.request.id, ) self.log_skipped(action_type, exc) @@ -136,7 +136,7 @@ def result(*args: Any, **kwargs: Any) -> Any: traceback.print_exc() logger.warning( "Skipping consent propagation on collection {} for privacy_request: {}", - self.traversal_node.address, + self.execution_node.address, self.resources.request.id, ) self.log_skipped(action_type, exc) @@ -154,61 +154,74 @@ def result(*args: Any, **kwargs: Any) -> Any: logger.warning( "Retrying {} {} in {} seconds...", method_name, - self.traversal_node.address, + self.execution_node.address, func_delay, ) sleep(func_delay) raised_ex = ex self.log_end(action_type, raised_ex) - self.resources.request.cache_failed_checkpoint_details( - step=action_type, collection=self.traversal_node.address - ) + self.resources.request.cache_failed_checkpoint_details(step=action_type) add_errored_system_status_for_consent_reporting( self.resources.session, self.resources.request, self.connector.configuration, ) - # Re-raise to stop privacy request execution on failure. 
- raise raised_ex # type: ignore + if not self.request_task.id: + # TODO Remove when we stop support for DSR 2.0 + # Re-raise to stop privacy request execution on failure for + # deprecated DSR 2.0 sequential execution + raise raised_ex # type: ignore + return default_return return result return decorator +def mark_current_and_downstream_nodes_as_failed( + privacy_request_task: RequestTask, db: Session +) -> None: + """ + For DSR 3.0, if the current node fails, mark it and *every descendant that can be reached by the current node* + as failed + """ + if not privacy_request_task.id: + return + + logger.info(f"Marking task {privacy_request_task.id} and descendants as errored") + + privacy_request_task.status = ExecutionLogStatus.error + db.add(privacy_request_task) + + for descendant_addr in privacy_request_task.all_descendant_tasks or []: + descendant: Optional[RequestTask] = ( + privacy_request_task.get_tasks_with_same_action_type(db, descendant_addr) + .filter(RequestTask.status == ExecutionLogStatus.pending) + .first() + ) + if not descendant: + continue + descendant.status = ExecutionLogStatus.error + db.add(descendant) + + db.commit() + + class GraphTask(ABC): # pylint: disable=too-many-instance-attributes """A task that operates on one traversal_node of a traversal""" def __init__( - self, traversal_node: TraversalNode, resources: TaskResources + self, resources: TaskResources ): # cache config, log config, db store config super().__init__() - self.traversal_node = traversal_node + self.request_task = resources.privacy_request_task + self.execution_node = ExecutionNode(resources.privacy_request_task) self.resources = resources self.connector: BaseConnector = resources.get_connector( - self.traversal_node.node.dataset.connection_key # ConnectionConfig.key - ) - self.data_uses: Set[str] = ( - System.get_data_uses( - [self.connector.configuration.system], include_parents=False - ) - if self.connector.configuration.system - else {} - ) - - # build incoming edges to the form : [dataset address: [(foreign field, local field)] - self.incoming_edges_by_collection: Dict[ - CollectionAddress, List[Edge] - ] = partition( - self.traversal_node.incoming_edges(), lambda e: e.f1.collection_address() - ) - - # the input keys this task will read from.These will build the dask graph - self.input_keys: List[CollectionAddress] = sorted( - self.incoming_edges_by_collection.keys() + self.execution_node.connection_key # ConnectionConfig.key ) - self.key = self.traversal_node.address + self.key: CollectionAddress = self.execution_node.address self.execution_log_id = None # a local copy of the execution log record written to. 
If we write multiple status @@ -218,56 +231,9 @@ def __init__( def __repr__(self) -> str: return f"{type(self)}:{self.key}" - @property - def grouped_fields(self) -> Set[str]: - """Convenience property - returns a set of fields that have been specified on the collection as dependent - upon one another - """ - return self.traversal_node.node.collection.grouped_inputs or set() - - @property - def dependent_identity_fields(self) -> bool: - """If the current collection needs inputs from other collections, in addition to its seed data.""" - collection = self.traversal_node.node.collection - for field in self.grouped_fields: - if collection.field(FieldPath(field)).identity: # type: ignore - return True - return False - - def build_incoming_field_path_maps( - self, group_dependent_fields: bool = False - ) -> Tuple[COLLECTION_FIELD_PATH_MAP, COLLECTION_FIELD_PATH_MAP]: - """ - For each collection connected to the current collection, return a list of tuples - mapping the foreign field to the local field. This is used to process data from incoming collections - into the current collection. - - :param group_dependent_fields: Whether we should split the incoming fields into two groups: one whose - fields are completely independent of one another, and the other whose incoming data needs to stay linked together. - If False, all fields are returned in the first tuple, and the second tuple just maps collections to an empty list. - - """ - - def field_map(keep: Callable) -> COLLECTION_FIELD_PATH_MAP: - return { - col_addr: [ - (edge.f1.field_path, edge.f2.field_path) - for edge in edge_list - if keep(edge.f2.field_path.string_path) - ] - for col_addr, edge_list in self.incoming_edges_by_collection.items() - } - - if group_dependent_fields: - return field_map( - lambda string_path: string_path not in self.grouped_fields - ), field_map(lambda string_path: string_path in self.grouped_fields) - - return field_map(lambda string_path: True), field_map(lambda string_path: False) - def generate_dry_run_query(self) -> Optional[str]: """Type-specific query generated for this traversal_node.""" - return self.connector.dry_run_query(self.traversal_node) + return self.connector.dry_run_query(self.execution_node) def can_write_data(self) -> bool: """Checks if the relevant ConnectionConfig has been granted "write" access to its data""" @@ -283,7 +249,7 @@ def _combine_seed_data( """Combine the seed data with the other dependent inputs. This is used when the seed data in a collection requires inputs from another collection to generate subsequent queries.""" # Get the identity values from the seeds that were passed into this collection. - seed_index = self.input_keys.index(ROOT_COLLECTION_ADDRESS) + seed_index = self.execution_node.input_keys.index(ROOT_COLLECTION_ADDRESS) seed_data = data[seed_index] for foreign_field_path, local_field_path in dependent_field_mappings[ @@ -321,11 +287,11 @@ def pre_process_input_data( The output dictionary is constructed with deduplicated values for each key, ensuring that the value lists and the fides_grouped_input list contain only unique elements. 
""" - if not len(data) == len(self.input_keys): + if not len(data) == len(self.execution_node.input_keys): logger.warning( "{} expected {} input keys, received {}", self, - len(self.input_keys), + len(self.execution_node.input_keys), len(data), ) @@ -335,14 +301,14 @@ def pre_process_input_data( ( independent_field_mappings, dependent_field_mappings, - ) = self.build_incoming_field_path_maps(group_dependent_fields) + ) = self.execution_node.build_incoming_field_path_maps(group_dependent_fields) for i, rowset in enumerate(data): - collection_address = self.input_keys[i] + collection_address = self.execution_node.input_keys[i] if ( group_dependent_fields - and self.dependent_identity_fields + and self.execution_node.dependent_identity_fields and collection_address == ROOT_COLLECTION_ADDRESS ): # Skip building data for the root collection if the seed data needs to be combined with other inputs @@ -350,7 +316,7 @@ def pre_process_input_data( logger.info( "Consolidating incoming data into {} from {}.", - self.traversal_node.node.address, + self.execution_node.address, collection_address, ) for row in rowset: @@ -376,7 +342,7 @@ def pre_process_input_data( ) grouped_data[local_field_path.string_path] = dependent_values - if self.dependent_identity_fields: + if self.execution_node.dependent_identity_fields: grouped_data = self._combine_seed_data( *data, grouped_data=grouped_data, @@ -394,21 +360,32 @@ def update_status( action_type: ActionType, status: ExecutionLogStatus, ) -> None: - """Update status activities""" - self.resources.write_execution_log( - self.traversal_node.node.dataset.connection_key, - self.traversal_node.address, - fields_affected, - action_type, - status, - msg, + """Update status activities - create an execution log (which stores historical logs) + and update the Request Task's current status. + """ + ExecutionLog.create( + db=self.resources.session, + data={ + "connection_key": self.execution_node.connection_key, + "dataset_name": self.execution_node.address.dataset, + "collection_name": self.execution_node.address.collection, + "fields_affected": fields_affected, + "action_type": action_type, + "status": status, + "privacy_request_id": self.resources.request.id, + "message": msg, + }, ) + if self.request_task.id: + # For DSR 3.0, updating the Request Task status when the ExecutionLog is + # created to keep these in sync. + # TODO remove conditional above alongside deprecating DSR 2.0 + self.request_task.update_status(self.resources.session, status) + def log_start(self, action_type: ActionType) -> None: """Task start activities""" - logger.info( - "Starting {}, traversal_node {}", self.resources.request.id, self.key - ) + logger.info("Starting {}, node {}", self.resources.request.id, self.key) self.update_status( "starting", [], action_type, ExecutionLogStatus.in_processing @@ -429,7 +406,8 @@ def log_paused(self, action_type: ActionType, ex: Optional[BaseException]) -> No def log_skipped(self, action_type: ActionType, ex: str) -> None: """Log that a collection was skipped. 
For now, this is because a collection has been disabled.""" logger.info("Skipping {}, node {}", self.resources.request.id, self.key) - + if action_type == ActionType.consent and self.request_task.id: + self.request_task.consent_sent = False self.update_status(str(ex), [], action_type, ExecutionLogStatus.skipped) def log_end( @@ -447,12 +425,18 @@ def log_end( Pii(ex), ) self.update_status(str(ex), [], action_type, ExecutionLogStatus.error) + # For DSR 3.0, Hooking into the GraphTask.log_end method to also mark the current + # Request Task and every Request Task that can be reached from the current + # task as errored. + mark_current_and_downstream_nodes_as_failed( + self.request_task, self.resources.session + ) else: logger.info("Ending {}, {}", self.resources.request.id, self.key) self.update_status( str(success_override_msg) if success_override_msg else "success", build_affected_field_logs( - self.traversal_node.node, self.resources.policy, action_type + self.execution_node, self.resources.policy, action_type ), action_type, ExecutionLogStatus.complete, @@ -478,10 +462,10 @@ def post_process_input_data( out: FieldPathNodeInput = {} for key, values in pre_processed_inputs.items(): path: FieldPath = FieldPath.parse(key) - field: Optional[Field] = self.traversal_node.node.collection.field(path) + field: Optional[Field] = self.execution_node.collection.field(path) if ( field - and path in self.traversal_node.query_field_paths + and path in self.execution_node.query_field_paths and isinstance(values, list) ): if field.return_all_elements: @@ -520,6 +504,18 @@ def access_results_post_processing( filter_element_match( row, query_paths=post_processed_node_input_data, delete_elements=False ) + + # For DSR 3.0, save data to build masking requests directly + # on the Request Task. + # Results saved with matching array elements preserved + if self.request_task.id: + self.request_task.data_for_erasures = json.dumps( + placeholder_output, cls=CustomJSONEncoder + ) + + # TODO Remove when we stop support for DSR 2.0 + # Save data to build masking requests for DSR 2.0 in Redis. + # Results saved with matching array elements preserved self.resources.cache_results_with_placeholders( f"access_request__{self.key}", placeholder_output ) @@ -528,9 +524,16 @@ def access_results_post_processing( for row in output: logger.info( "Filtering row in {} for matching array elements.", - self.traversal_node.node.address, + self.execution_node.address, ) filter_element_match(row, post_processed_node_input_data) + + if self.request_task.id: + # Saves intermediate access results for DSR 3.0 directly on the Request Task + self.request_task.access_data = json.dumps(output, cls=CustomJSONEncoder) + + # TODO Remove when we stop support for DSR 2.0 + # Saves intermediate access results for DSR 2.0 in Redis self.resources.cache_object(f"access_request__{self.key}", output) # Return filtered rows with non-matched array data removed. @@ -541,7 +544,7 @@ def skip_if_disabled(self) -> None: connection_config: ConnectionConfig = self.connector.configuration if connection_config.disabled: raise CollectionDisabled( - f"Skipping collection {self.traversal_node.node.address}. " + f"Skipping collection {self.execution_node.address}. " f"ConnectionConfig {connection_config.key} is disabled.", ) @@ -558,7 +561,7 @@ def skip_if_action_disabled(self, action_type: ActionType) -> None: and action_type not in connection_config.enabled_actions ): raise ActionDisabled( - f"Skipping collection {self.traversal_node.node.address}. 
" + f"Skipping collection {self.execution_node.address}. " f"The {action_type} action is disabled for connection config with key '{connection_config.key}'.", ) @@ -569,9 +572,10 @@ def access_request(self, *inputs: List[Row]) -> List[Row]: *inputs, group_dependent_fields=True ) output: List[Row] = self.connector.retrieve_data( - self.traversal_node, + self.execution_node, self.resources.policy, self.resources.request, + self.resources.privacy_request_task, formatted_input_data, ) filtered_output: List[Row] = self.access_results_post_processing( @@ -584,32 +588,40 @@ def access_request(self, *inputs: List[Row]) -> List[Row]: def erasure_request( self, retrieved_data: List[Row], - inputs: List[List[Row]], - *erasure_prereqs: int, + *erasure_prereqs: int, # TODO Remove when we stop support for DSR 2.0. DSR 3.0 enforces with downstream_tasks. ) -> int: """Run erasure request""" # if there is no primary key specified in the graph node configuration # note this in the execution log and perform no erasures on this node - if not self.traversal_node.node.contains_field(lambda f: f.primary_key): + if not self.execution_node.collection.contains_field(lambda f: f.primary_key): logger.warning( "No erasures on {} as there is no primary_key defined.", - self.traversal_node.node.address, + self.execution_node.address, ) + if self.request_task.id: + # For DSR 3.0, largely for testing. DSR 3.0 uses Request Task status + # instead of presence of cached erasure data to know if we should rerun a node + self.request_task.rows_masked = 0 # Saved as part of update_status + # TODO Remove when we stop support for DSR 2.0 + self.resources.cache_erasure(self.key.value, 0) self.update_status( "No values were erased since no primary key was defined for this collection", None, ActionType.erasure, ExecutionLogStatus.complete, ) - # Cache that the erasure was performed in case we need to restart - self.resources.cache_erasure(self.key.value, 0) return 0 if not self.can_write_data(): logger.warning( "No erasures on {} as its ConnectionConfig does not have write access.", - self.traversal_node.node.address, + self.execution_node.address, ) + if self.request_task.id: + # DSR 3.0 + self.request_task.rows_masked = 0 # Saved as part of update_status + # TODO Remove when we stop support for DSR 2.0 + self.resources.cache_erasure(self.key.value, 0) self.update_status( f"No values were erased since this connection {self.connector.configuration.key} has not been " f"given write access", @@ -617,24 +629,26 @@ def erasure_request( ActionType.erasure, ExecutionLogStatus.error, ) - self.resources.cache_erasure(self.key.value, 0) return 0 - formatted_input_data: NodeInput = self.pre_process_input_data( - *inputs, group_dependent_fields=True - ) - output = self.connector.mask_data( - self.traversal_node, + self.execution_node, self.resources.policy, self.resources.request, + self.resources.privacy_request_task, retrieved_data, - formatted_input_data, ) - self.log_end(ActionType.erasure) + if self.request_task.id: + # For DSR 3.0, largely for testing. 
DSR 3.0 uses Request Task status
+            # instead of presence of cached erasure data to know if we should rerun a node
+            self.request_task.rows_masked = (
+                output  # Saved as part of update_status below
+            )
+        # TODO Remove when we stop support for DSR 2.0
         self.resources.cache_erasure(
-            f"{self.key}", output
+            self.key.value, output
         )  # Cache that the erasure was performed in case we need to restart
+        self.log_end(ActionType.erasure)
         return output

     @retry(action_type=ActionType.consent, default_return=False)
@@ -643,24 +657,29 @@ def consent_request(self, identity: Dict[str, Any]) -> bool:
         if not self.can_write_data():
             logger.warning(
                 "No consent on {} as its ConnectionConfig does not have write access.",
-                self.traversal_node.node.address,
+                self.execution_node.address,
             )
+            if self.request_task.id:
+                # For DSR 3.0, saved as part of update_status below
+                self.request_task.consent_sent = False
             self.update_status(
-                f"No values were erased since this connection {self.connector.configuration.key} has not been "
+                f"No consent requests were sent since this connection {self.connector.configuration.key} has not been "
                 f"given write access",
                 None,
-                ActionType.erasure,
+                ActionType.consent,
                 ExecutionLogStatus.error,
             )
             return False

         output: bool = self.connector.run_consent_request(
-            self.traversal_node,
+            self.execution_node,
             self.resources.policy,
             self.resources.request,
+            self.resources.privacy_request_task,
             identity,
             self.resources.session,
         )
+        self.request_task.consent_sent = output
         self.log_end(ActionType.consent)
         return output

@@ -674,55 +693,15 @@ def collect_queries_fn(
        tn: TraversalNode, data: Dict[CollectionAddress, str]
    ) -> None:
        if not tn.is_root_node():
-            data[tn.address] = GraphTask(tn, resources).generate_dry_run_query()  # type: ignore
+            # Mock a RequestTask object in memory
+            resources.privacy_request_task = tn.to_mock_request_task()
+            data[tn.address] = GraphTask(resources).generate_dry_run_query()  # type: ignore

    env: Dict[CollectionAddress, str] = {}
    traversal.traverse(env, collect_queries_fn)
    return env
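
The dry-run hunk above uses a pattern that recurs throughout this diff: mock a RequestTask in memory so a GraphTask can be built without a persisted row. A minimal sketch, assuming `to_mock_request_task` returns an unsaved RequestTask (as the mock-object comments in this diff suggest):

```python
# Illustrative only: build a GraphTask without persisting a RequestTask row.
resources.privacy_request_task = traversal_node.to_mock_request_task()
task = GraphTask(resources)  # GraphTask reads resources.privacy_request_task

# An unsaved mock has no id, so the `if self.request_task.id:` guards in this
# diff keep the DSR 3.0 persistence paths (access_data, rows_masked,
# consent_sent) disabled for DSR 2.0-style in-memory execution.
assert not task.request_task.id
```

-def update_mapping_from_cache(
-    dsk: Dict[CollectionAddress, Tuple[Any, ...]],
-    resources: TaskResources,
-    start_fn: Callable,
-) -> None:
-    """When resuming a privacy request from a paused or failed state, update the `dsk` dictionary with results we've
-    already obtained from a previous run. Remove upstream dependencies for these nodes, and just return the data we've
-    already retrieved, rather than visiting them again.
-
-    If there's no cached data, the dsk dictionary won't change.
-    """
-
-    cached_results: Dict[str, Optional[List[Row]]] = resources.get_all_cached_objects()
-
-    for collection_name in cached_results:
-        dsk[CollectionAddress.from_string(collection_name)] = (
-            start_fn(cached_results[collection_name]),
-        )
-
-
-def _format_data_use_map_for_caching(
-    env: Dict[CollectionAddress, "GraphTask"]
-) -> Dict[str, Set[str]]:
-    """
-    Create a map of `Collection`s mapped to their associated `DataUse`s
-    to be stored in the cache. This is done before request execution, so that we
-    maintain the _original_ state of the graph as it's used for request execution.
-    The graph is subject to change "from underneath" the request execution runtime,
-    but we want to avoid picking up those changes in our data use map.
-
-    `DataUse`s are associated with a `Collection` by means of the `System`
-    that's linked to a `Collection`'s `Connection` definition.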
- - Example: - { - : {"data_use_1", "data_use_2"}, - : {"data_use_1"}, - } - """ - return {collection.value: g_task.data_uses for collection, g_task in env.items()} - - def start_function(seed: List[Dict[str, Any]]) -> Callable[[], List[Dict[str, Any]]]: """Return a function for collections with no upstream dependencies, that just start with seed data. @@ -735,69 +714,6 @@ def g() -> List[Dict[str, Any]]: return g -async def run_access_request( - privacy_request: PrivacyRequest, - policy: Policy, - graph: DatasetGraph, - connection_configs: List[ConnectionConfig], - identity: Dict[str, Any], - session: Session, -) -> Dict[str, List[Row]]: - """Run the access request""" - traversal: Traversal = Traversal(graph, identity) - with TaskResources( - privacy_request, policy, connection_configs, session - ) as resources: - - def collect_tasks_fn( - tn: TraversalNode, data: Dict[CollectionAddress, GraphTask] - ) -> None: - """Run the traversal, as an action creating a GraphTask for each traversal_node.""" - if not tn.is_root_node(): - data[tn.address] = GraphTask(tn, resources) - - def termination_fn( - *dependent_values: List[Row], - ) -> Dict[str, Optional[List[Row]]]: - """A termination function that just returns its inputs mapped to their source addresses. - This needs to wait for all dependent keys because this is how dask is informed to wait for - all terminating addresses before calling this.""" - - return resources.get_all_cached_objects() - - env: Dict[CollectionAddress, Any] = {} - end_nodes = traversal.traverse(env, collect_tasks_fn) - - dsk: Dict[CollectionAddress, Tuple[Any, ...]] = { - k: (t.access_request, *t.input_keys) for k, t in env.items() - } - dsk[ROOT_COLLECTION_ADDRESS] = (start_function([traversal.seed_data]),) - dsk[TERMINATOR_ADDRESS] = (termination_fn, *end_nodes) - update_mapping_from_cache(dsk, resources, start_function) - - await fideslog_graph_rerun( - prepare_rerun_graph_analytics_event( - privacy_request, env, end_nodes, resources, ActionType.access - ) - ) - - # cache access graph for use in logging/analytics event - privacy_request.cache_access_graph(format_graph_for_caching(env, end_nodes)) - - # cache a map of collections -> data uses for the output package of access requests - # this is cached here before request execution, since this is the state of the - # graph used for request execution. the graph could change _during_ request execution, - # but we don't want those changes in our data use map. - privacy_request.cache_data_use_map(_format_data_use_map_for_caching(env)) - - v = delayed(get(dsk, TERMINATOR_ADDRESS, num_workers=1)) - access_results = v.compute() - filtered_access_results = filter_by_enabled_actions( - access_results, connection_configs - ) - return filtered_access_results - - def filter_by_enabled_actions( access_results: Dict[str, Any], connection_configs: List[ConnectionConfig] ) -> Dict[str, Any]: @@ -839,172 +755,8 @@ def get_cached_data_for_erasures( } -def update_erasure_mapping_from_cache( - dsk: Dict[CollectionAddress, Union[Tuple[Any, ...], int]], resources: TaskResources -) -> None: - """On pause or restart from failure, update the dsk graph to skip running erasures on collections - we've already visited. Instead, just return the previous count of rows affected. - - If there's no cached data, the dsk dictionary won't change. 
- """ - cached_erasures: Dict[str, int] = resources.get_all_cached_erasures() - - for collection_name in cached_erasures: - dsk[CollectionAddress.from_string(collection_name)] = cached_erasures[ - collection_name - ] - - -async def run_erasure( # pylint: disable = too-many-arguments - privacy_request: PrivacyRequest, - policy: Policy, - graph: DatasetGraph, - connection_configs: List[ConnectionConfig], - identity: Dict[str, Any], - access_request_data: Dict[str, List[Row]], - session: Session, -) -> Dict[str, int]: - """Run an erasure request""" - traversal: Traversal = Traversal(graph, identity) - with TaskResources( - privacy_request, policy, connection_configs, session - ) as resources: - - def collect_tasks_fn( - tn: TraversalNode, data: Dict[CollectionAddress, GraphTask] - ) -> None: - """Run the traversal, as an action creating a GraphTask for each traversal_node.""" - if not tn.is_root_node(): - data[tn.address] = GraphTask(tn, resources) - - # We store the end nodes from the traversal for analytics purposes - # but we generate a separate erasure_end_nodes list for the actual erasure traversal - env: Dict[CollectionAddress, Any] = {} - access_end_nodes = traversal.traverse(env, collect_tasks_fn) - erasure_end_nodes = list(graph.nodes.keys()) - - def termination_fn(*dependent_values: int) -> Dict[str, int]: - """ - The erasure order can be affected in a way that not every node is directly linked - to the termination node. This means that we can't just aggregate the inputs directly, - we must read the erasure results from the cache. - """ - return resources.get_all_cached_erasures() - - access_request_data[ROOT_COLLECTION_ADDRESS.value] = [identity] - - dsk: Dict[CollectionAddress, Any] = { - k: ( - t.erasure_request, - access_request_data.get( - str(k), [] - ), # Pass in the results of the access request for this collection - [ - access_request_data.get( - str(upstream_key), [] - ) # Additionally pass in the original input data we used for the access request. It's helpful in - # cases like the EmailConnector where the access request doesn't actually retrieve data. - for upstream_key in t.input_keys - ], - *_evaluate_erasure_dependencies(t, erasure_end_nodes), - ) - for k, t in env.items() - } - - # root node returns 0 to be consistent with the output of the other erasure tasks - dsk[ROOT_COLLECTION_ADDRESS] = 0 - # terminator function reads and returns the cached erasure results for the entire erasure traversal - dsk[TERMINATOR_ADDRESS] = (termination_fn, *erasure_end_nodes) - update_erasure_mapping_from_cache(dsk, resources) - await fideslog_graph_rerun( - prepare_rerun_graph_analytics_event( - privacy_request, env, access_end_nodes, resources, ActionType.erasure - ) - ) - - # using an existing function from dask.core to detect cycles in the generated graph - collection_cycle = getcycle(dsk, None) - if collection_cycle: - raise TraversalError( - f"The values for the `erase_after` fields caused a cycle in the following collections {collection_cycle}" - ) - - v = delayed(get(dsk, TERMINATOR_ADDRESS, num_workers=1)) - return v.compute() - - -def _evaluate_erasure_dependencies( - t: GraphTask, end_nodes: List[CollectionAddress] -) -> Set[CollectionAddress]: - """ - Return a set of collection addresses corresponding to collections that need - to be erased before the given task. Remove the dependent collection addresses - from `end_nodes` so they can be executed in the correct order. 
If a task does - not have any dependencies it is linked directly to the root node - """ - erase_after = t.traversal_node.node.collection.erase_after - for collection in erase_after: - if collection in end_nodes: - # end_node list is modified in place - end_nodes.remove(collection) - # this task will execute after the collections in `erase_after` or - # execute at the beginning by linking it to the root node - return erase_after if len(erase_after) else {ROOT_COLLECTION_ADDRESS} - - -async def run_consent_request( # pylint: disable = too-many-arguments - privacy_request: PrivacyRequest, - policy: Policy, - graph: DatasetGraph, - connection_configs: List[ConnectionConfig], - identity: Dict[str, Any], - session: Session, -) -> Dict[str, bool]: - """Run a consent request - - The graph built is very simple: there are no relationships between the nodes, every node has - identity data input and every node outputs whether the consent request succeeded. - - The DatasetGraph passed in is expected to have one Node per Dataset. That Node is expected to carry out requests - for the Dataset as a whole. - """ - - with TaskResources( - privacy_request, policy, connection_configs, session - ) as resources: - graph_keys: List[CollectionAddress] = list(graph.nodes.keys()) - dsk: Dict[CollectionAddress, Any] = {} - - for col_address, node in graph.nodes.items(): - traversal_node = TraversalNode(node) - task = GraphTask(traversal_node, resources) - dsk[col_address] = (task.consent_request, identity) - - def termination_fn(*dependent_values: bool) -> Tuple[bool, ...]: - """The dependent_values here is an bool output from each task feeding in, where - each task reports the output of 'task.consent_request(identity_data)', which is whether the - consent request succeeded - - The termination function just returns this tuple of booleans.""" - return dependent_values - - # terminator function waits for all keys - dsk[TERMINATOR_ADDRESS] = (termination_fn, *graph_keys) - - v = delayed(get(dsk, TERMINATOR_ADDRESS, num_workers=1)) - - update_successes: Tuple[bool, ...] = v.compute() - # we combine the output of the termination function with the input keys to provide - # a map of {collection_name: whether consent request succeeded}: - consent_update_map: Dict[str, bool] = dict( - zip([coll.value for coll in graph_keys], update_successes) - ) - - return consent_update_map - - def build_affected_field_logs( - node: Node, policy: Policy, action_type: ActionType + node: ExecutionNode, policy: Policy, action_type: ActionType ) -> List[Dict[str, Any]]: """For a given node (collection), policy, and action_type (access or erasure) format all of the fields that were potentially touched to be stored in the ExecutionLogs for troubleshooting. @@ -1050,3 +802,28 @@ def build_affected_field_logs( ) return ret + + +def build_consent_dataset_graph(datasets: List[DatasetConfig]) -> DatasetGraph: + """ + Build the starting DatasetGraph for consent requests. + + Consent Graph has one node per dataset. Nodes must be of saas type and have consent requests defined. 
+ """ + consent_datasets: List[GraphDataset] = [] + + for dataset_config in datasets: + connection_type: ConnectionType = ( + dataset_config.connection_config.connection_type # type: ignore + ) + saas_config: Optional[Dict] = dataset_config.connection_config.saas_config + if ( + connection_type == ConnectionType.saas + and saas_config + and saas_config.get("consent_requests") + ): + consent_datasets.append( + dataset_config.get_dataset_with_stubbed_collection() # type: ignore[arg-type, assignment] + ) + + return DatasetGraph(*consent_datasets) diff --git a/src/fides/api/task/task_resources.py b/src/fides/api/task/task_resources.py index 7bf86f37d6..179ab04109 100644 --- a/src/fides/api/task/task_resources.py +++ b/src/fides/api/task/task_resources.py @@ -5,21 +5,14 @@ from sqlalchemy.orm import Session from fides.api.common_exceptions import ConnectorNotFoundException -from fides.api.graph.config import CollectionAddress from fides.api.models.connectionconfig import ConnectionConfig, ConnectionType from fides.api.models.policy import Policy -from fides.api.models.privacy_request import ( - ExecutionLog, - ExecutionLogStatus, - PrivacyRequest, -) -from fides.api.schemas.policy import ActionType +from fides.api.models.privacy_request import PrivacyRequest, RequestTask from fides.api.service.connectors import ( BaseConnector, BigQueryConnector, DynamoDBConnector, FidesConnector, - ManualConnector, MariaDBConnector, MicrosoftSQLServerConnector, MongoDBConnector, @@ -75,8 +68,6 @@ def build_connector( # pylint: disable=R0911,R0912 return BigQueryConnector(connection_config) if connection_config.connection_type == ConnectionType.saas: return SaaSConnector(connection_config) - if connection_config.connection_type == ConnectionType.manual: - return ManualConnector(connection_config) if connection_config.connection_type == ConnectionType.timescale: return TimescaleConnector(connection_config) if connection_config.connection_type == ConnectionType.dynamodb: @@ -96,12 +87,14 @@ def close(self) -> None: class TaskResources: - """Shared information and environment for all nodes of a given task. + """Holds some Database resources for the given task. + Importantly, should be used as a context manager, to close connections to external databases. + This includes - the privacy request + - the request task - the policy - - redis connection - - configurations to any outside resources the task will require to run + - configurations to any outside resources the task will require to run """ def __init__( @@ -109,12 +102,15 @@ def __init__( request: PrivacyRequest, policy: Policy, connection_configs: List[ConnectionConfig], + privacy_request_task: RequestTask, session: Session, ): self.request = request + self.policy = policy + # TODO Remove when we stop support for DSR 2.0 self.cache = get_cache() - # tbd populate connection configurations. + self.privacy_request_task = privacy_request_task self.connection_configs: Dict[str, ConnectionConfig] = { c.key: c for c in connection_configs } @@ -129,6 +125,7 @@ def __exit__(self, _type: Any, value: Any, traceback: Any) -> None: """Support 'with' usage for closing resources""" self.close() + # TODO Remove when we stop support for DSR 2.0 def cache_results_with_placeholders(self, key: str, value: Any) -> None: """Cache raw results from node. 
Object will be stored in redis under 'PLACEHOLDER_RESULTS__PRIVACY_REQUEST_ID__TYPE__COLLECTION_ADDRESS
@@ -137,10 +134,12 @@ def cache_results_with_placeholders(self, key: str, value: Any) -> None:
             f"PLACEHOLDER_RESULTS__{self.request.id}__{key}", value
         )

+    # TODO Remove when we stop support for DSR 2.0
     def cache_object(self, key: str, value: Any) -> None:
         """Store in cache. Object will be stored in redis under 'REQUEST_ID__TYPE__ADDRESS'"""
         self.cache.set_encoded_object(f"{self.request.id}__{key}", value)

+    # TODO Remove when we stop support for DSR 2.0
     def get_all_cached_objects(self) -> Dict[str, Optional[List[Row]]]:
         """Retrieve the access results of all steps (cache_object)"""
         value_dict = self.cache.get_encoded_objects_by_prefix(
@@ -153,6 +152,7 @@ def get_all_cached_objects(self) -> Dict[str, Optional[List[Row]]]:
             for k, v in value_dict.items()
         }

+    # TODO Remove when we stop support for DSR 2.0
     def cache_erasure(self, key: str, value: int) -> None:
         """Cache that a node's masking is complete. Object will be stored in redis under
         'REQUEST_ID__erasure_request__ADDRESS
@@ -161,6 +161,7 @@ def cache_erasure(self, key: str, value: int) -> None:
             f"{self.request.id}__erasure_request__{key}", value
         )

+    # TODO Remove when we stop support for DSR 2.0
     def get_all_cached_erasures(self) -> Dict[str, int]:
         """Retrieve which collections have been masked and their row counts (cache_erasure)"""
         value_dict = self.cache.get_encoded_objects_by_prefix(
@@ -170,32 +171,6 @@ def get_all_cached_erasures(self) -> Dict[str, int]:
        number_of_leading_strings_to_exclude = 2
        return {extract_key_for_address(k, number_of_leading_strings_to_exclude): v for k, v in value_dict.items()}  # type: ignore

-    def write_execution_log(  # pylint: disable=too-many-arguments
-        self,
-        connection_key: str,
-        collection_address: CollectionAddress,
-        fields_affected: Any,
-        action_type: ActionType,
-        status: ExecutionLogStatus,
-        message: str = None,
-    ) -> Any:
-        """Store in application db. Return the created or written-to id field value."""
-        db = self.session
-
-        ExecutionLog.create(
-            db=db,
-            data={
-                "connection_key": connection_key,
-                "dataset_name": collection_address.dataset,
-                "collection_name": collection_address.collection,
-                "fields_affected": fields_affected,
-                "action_type": action_type,
-                "status": status,
-                "privacy_request_id": self.request.id,
-                "message": message,
-            },
-        )
-
     def get_connector(self, key: FidesKey) -> Any:
         """Create or return the client corresponding to the given ConnectionConfig key"""
         if key in self.connection_configs:
@@ -203,6 +178,11 @@ def get_connector(self, key: FidesKey) -> Any:
         raise ConnectorNotFoundException(f"No available connector for {key}")

     def close(self) -> None:
-        """Close any held resources"""
+        """Close any held resources
+
+        Note that using TaskResources as a context manager calls
+        self.connections.close() to close connections to external databases. This is
+        really important to avoid opening too many connections.
+ """ logger.debug("Closing all task resources for {}", self.request.id) self.connections.close() diff --git a/src/fides/api/util/cache.py b/src/fides/api/util/cache.py index a981a2f307..791a8ff4e0 100644 --- a/src/fides/api/util/cache.py +++ b/src/fides/api/util/cache.py @@ -9,9 +9,11 @@ from redis import Redis from redis.client import Script # type: ignore from redis.exceptions import ConnectionError as ConnectionErrorFromRedis +from redis.exceptions import DataError from fides.api import common_exceptions from fides.api.schemas.masking.masking_secrets import SecretType +from fides.api.tasks import celery_app from fides.config import CONFIG # This constant represents every type a redis key may contain, and can be @@ -240,3 +242,61 @@ def get_all_cache_keys_for_privacy_request(privacy_request_id: str) -> List[Any] def get_async_task_tracking_cache_key(privacy_request_id: str) -> str: return f"id-{privacy_request_id}-async-execution" + + +def cache_task_tracking_key(request_id: str, celery_task_id: str) -> None: + """ + Cache the celery task id created to run the Privacy Request or Request Task. + + Note that it is possible a Privacy Request or Request Task is queued multiple times + over the life of a Priavcy Request so the cached id is the latest task queued + + :param request_id: Can be the Privacy Request Id or a Request Task ID - these are cached in the same place. + :param celery_task_id: The id of the Celery task itself that was queued to run the + Privacy Request or the Request Task + :return: None + """ + + cache: FidesopsRedis = get_cache() + + try: + cache.set( + get_async_task_tracking_cache_key(request_id), + celery_task_id, + ) + except DataError: + logger.debug( + "Error tracking task_id for privacy request or request task with id {}", + request_id, + ) + + +def celery_tasks_in_flight(celery_task_ids: List[str]) -> bool: + """Returns True if supplied Celery Tasks appear to be in-flight""" + if not celery_task_ids: + return False + + queried_tasks = celery_app.control.inspect().query_task(*celery_task_ids) + if not queried_tasks: + return False + + # Expected format: {HOSTNAME: {TASK_ID: [STATE, TASK_INFO]}} + for _, task_details in queried_tasks.items(): + for _, state_array in task_details.items(): + state: str = state_array[0] + # Note, not positive of states here, + # some seen in testing, some from here: + # https://github.com/celery/celery/blob/main/celery/worker/control.py or + # https://github.com/celery/celery/blob/main/celery/states.py + if state in [ + "active", + "received", + "registered", + "reserved", + "retry", + "scheduled", + "started", + ]: + return True + + return False diff --git a/src/fides/api/util/saas_util.py b/src/fides/api/util/saas_util.py index 3b2d7d1735..af2517ff88 100644 --- a/src/fides/api/util/saas_util.py +++ b/src/fides/api/util/saas_util.py @@ -424,7 +424,6 @@ def get_identity(privacy_request: Optional[PrivacyRequest]) -> Optional[str]: if not privacy_request: return None - identities: List[str] = [] identity_data: Dict[str, Any] = privacy_request.get_cached_identity_data() # filters out keys where associated value is None or empty str identities = list({k for k, v in identity_data.items() if v}) diff --git a/src/fides/api/worker/__init__.py b/src/fides/api/worker/__init__.py index 543d1993eb..7764005c68 100644 --- a/src/fides/api/worker/__init__.py +++ b/src/fides/api/worker/__init__.py @@ -1,6 +1,6 @@ import json - from typing import Any + from celery import VERSION_BANNER from celery.apps.worker import Worker from celery.signals import 
celeryd_after_setup @@ -38,9 +38,7 @@ def log_celery_setup(sender: str, instance: Worker, **kwargs: Any) -> None: "queues": "|".join(str(queue) for queue in app.amqp.queues.keys()), } - logger.bind( - celery_details=celery_details - ).info("Celery connection setup complete") + logger.bind(celery_details=celery_details).info("Celery connection setup complete") if __name__ == "__main__": # pragma: no cover diff --git a/src/fides/common/api/v1/urn_registry.py b/src/fides/common/api/v1/urn_registry.py index e788ade5e1..b1b5fabb18 100644 --- a/src/fides/common/api/v1/urn_registry.py +++ b/src/fides/common/api/v1/urn_registry.py @@ -73,10 +73,11 @@ PRIVACY_REQUEST_BULK_RETRY = "/privacy-request/bulk/retry" PRIVACY_REQUEST_DENY = "/privacy-request/administrate/deny" REQUEST_STATUS_LOGS = "/privacy-request/{privacy_request_id}/log" +REQUEST_TASKS = "/privacy-request/{privacy_request_id}/tasks" +PRIVACY_REQUEST_REQUEUE = "/privacy-request/{privacy_request_id}/requeue" + PRIVACY_REQUEST_VERIFY_IDENTITY = "/privacy-request/{privacy_request_id}/verify" PRIVACY_REQUEST_RESUME = "/privacy-request/{privacy_request_id}/resume" -PRIVACY_REQUEST_MANUAL_INPUT = "/privacy-request/{privacy_request_id}/manual_input" -PRIVACY_REQUEST_MANUAL_ERASURE = "/privacy-request/{privacy_request_id}/erasure_confirm" PRIVACY_REQUEST_NOTIFICATIONS = "/privacy-request/notification" PRIVACY_REQUEST_RETRY = "/privacy-request/{privacy_request_id}/retry" REQUEST_PREVIEW = "/privacy-request/preview" diff --git a/src/fides/config/execution_settings.py b/src/fides/config/execution_settings.py index e0f5e11ded..7bcad29b08 100644 --- a/src/fides/config/execution_settings.py +++ b/src/fides/config/execution_settings.py @@ -48,6 +48,18 @@ class ExecutionSettings(FidesSettings): default=False, description="Allows custom privacy request fields to be used in request execution.", ) + request_task_ttl: int = Field( + default=604800, + description="The number of seconds a request task should live.", + ) + state_polling_interval: int = Field( + default=30, + description="Seconds between polling for Privacy Requests that should change state", + ) + use_dsr_3_0: bool = Field( + default=False, + description="Temporary flag to switch to using DSR 3.0 to process your tasks.", + ) class Config: env_prefix = ENV_PREFIX diff --git a/tests/conftest.py b/tests/conftest.py index 5d5aefa572..ec616c2aeb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import asyncio import json import os +import time from datetime import datetime from pathlib import Path from typing import Callable @@ -9,6 +10,7 @@ import pytest import requests import yaml +from fastapi import Query from fastapi.testclient import TestClient from fideslang import DEFAULT_TAXONOMY, models from httpx import AsyncClient @@ -17,6 +19,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker +from fides.api.common_exceptions import PrivacyRequestExit from fides.api.cryptography.schemas.jwt import ( JWE_ISSUED_AT, JWE_PAYLOAD_CLIENT_ID, @@ -26,18 +29,17 @@ ) from fides.api.db.ctl_session import sync_engine from fides.api.main import app -from fides.api.models.privacy_request import generate_request_callback_jwe -from fides.api.models.sql_models import Cookies, DataUse, PrivacyDeclaration -from fides.api.oauth.jwt import generate_jwe -from fides.api.oauth.roles import ( - APPROVER, - CONTRIBUTOR, - OWNER, - VIEWER, - VIEWER_AND_APPROVER, +from fides.api.models.privacy_request import ( + EXITED_EXECUTION_LOG_STATUSES, + 
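The three new execution settings follow the existing FidesSettings pattern, so they should be configurable both in fides.toml (the `use_dsr_3_0 = false` hunk earlier in this diff) and via environment variables. A standalone sketch of that behavior with the same fields and defaults; the class name is made up, and the `FIDES__EXECUTION__` prefix is an assumption based on the module's `ENV_PREFIX`:

```python
import os

from pydantic import BaseSettings, Field  # pydantic v1-style settings, as used here


class ExecutionSettingsSketch(BaseSettings):
    # Mirrors the fields added to ExecutionSettings, with the same defaults.
    request_task_ttl: int = Field(default=604800)    # one week, in seconds
    state_polling_interval: int = Field(default=30)  # seconds between state polls
    use_dsr_3_0: bool = Field(default=False)         # temporary opt-in flag

    class Config:
        env_prefix = "FIDES__EXECUTION__"  # assumed prefix


os.environ["FIDES__EXECUTION__USE_DSR_3_0"] = "true"
assert ExecutionSettingsSketch().use_dsr_3_0 is True
```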
generate_request_callback_jwe, ) +from fides.api.models.sql_models import Cookies, DataUse +from fides.api.oauth.jwt import generate_jwe +from fides.api.oauth.roles import APPROVER, CONTRIBUTOR, OWNER, VIEWER_AND_APPROVER from fides.api.schemas.messaging.messaging import MessagingServiceType +from fides.api.task.graph_runners import access_runner, consent_runner, erasure_runner from fides.api.util.cache import get_cache +from fides.api.util.collection_util import Row from fides.common.api.scope_registry import SCOPE_REGISTRY from fides.config import get_config from fides.config.config_proxy import ConfigProxy @@ -625,6 +627,121 @@ def run_privacy_request_task(celery_session_app): ] + +class DSRThreeTestRunnerTimedOut(Exception): + """DSR 3.0 Test Runner Timed Out""" + + +def wait_for_tasks_to_complete( + db: Session, pr: PrivacyRequest, action_type: ActionType +): + """Testing helper for DSR 3.0 - repeatedly checks whether all Request Tasks + have exited so a bogged-down test doesn't hang""" + + def all_tasks_have_run(tasks: Query) -> bool: + return all(tsk.status in EXITED_EXECUTION_LOG_STATUSES for tsk in tasks) + + db.commit() + counter = 0 + while not all_tasks_have_run( + ( + db.query(RequestTask).filter( + RequestTask.privacy_request_id == pr.id, + RequestTask.action_type == action_type, + ) + ) + ): + time.sleep(1) + counter += 1 + if counter == 5: + raise DSRThreeTestRunnerTimedOut() + + +def access_runner_tester( + privacy_request: PrivacyRequest, + policy: Policy, + graph: DatasetGraph, + connection_configs: List[ConnectionConfig], + identity: Dict[str, Any], + session: Session, +): + """ + Function for testing the access request for either DSR 2.0 or DSR 3.0 + """ + try: + return access_runner( + privacy_request, + policy, + graph, + connection_configs, + identity, + session, + privacy_request_proceed=False, # This allows the DSR 3.0 Access Runner to be tested in isolation, to just test running the access graph without queuing the privacy request + ) + except PrivacyRequestExit: + # DSR 3.0 intentionally raises a PrivacyRequestExit status while it waits for + # RequestTasks to finish + wait_for_tasks_to_complete(session, privacy_request, ActionType.access) + return privacy_request.get_raw_access_results() + + +def erasure_runner_tester( + privacy_request: PrivacyRequest, + policy: Policy, + graph: DatasetGraph, + connection_configs: List[ConnectionConfig], + identity: Dict[str, Any], + access_request_data: Dict[str, List[Row]], + session: Session, +): + """ + Function for testing the erasure runner for either DSR 2.0 or DSR 3.0 + """ + try: + return erasure_runner( + privacy_request, + policy, + graph, + connection_configs, + identity, + access_request_data, + session, + privacy_request_proceed=False, # This allows the DSR 3.0 Erasure Runner to be tested in isolation + ) + except PrivacyRequestExit: + # DSR 3.0 intentionally raises a PrivacyRequestExit status while it waits + # for RequestTasks to finish + wait_for_tasks_to_complete(session, privacy_request, ActionType.erasure) + return privacy_request.get_raw_masking_counts() + + +def consent_runner_tester( + privacy_request: PrivacyRequest, + policy: Policy, + graph: DatasetGraph, + connection_configs: List[ConnectionConfig], + identity: Dict[str, Any], + session: Session, +): + """ + Function for testing the consent request for either DSR 2.0 or DSR 3.0 + """ + try: + return consent_runner( + privacy_request, + policy, + graph, + connection_configs, + identity, + session, + privacy_request_proceed=False, # This allows the
DSR 3.0 Consent Runner to be tested in isolation, to just test running the consent graph without queuing the privacy request + ) + except PrivacyRequestExit: + # DSR 3.0 intentionally raises a PrivacyRequestExit status while it waits for + # RequestTasks to finish + wait_for_tasks_to_complete(session, privacy_request, ActionType.consent) + return privacy_request.get_consent_results() + + @pytest.fixture(autouse=True, scope="session") def analytics_opt_out(): """Disable sending analytics when running tests.""" diff --git a/tests/fixtures/application_fixtures.py b/tests/fixtures/application_fixtures.py index 3fe2e1a4ce..3b1474a7fa 100644 --- a/tests/fixtures/application_fixtures.py +++ b/tests/fixtures/application_fixtures.py @@ -16,6 +16,7 @@ from toml import load as load_toml from fides.api.common_exceptions import SystemManagerException +from fides.api.graph.graph import DatasetGraph from fides.api.models.application_config import ApplicationConfig from fides.api.models.audit_log import AuditLog, AuditLogAction from fides.api.models.client import ClientDetail @@ -24,7 +25,7 @@ ConnectionConfig, ConnectionType, ) -from fides.api.models.datasetconfig import DatasetConfig +from fides.api.models.datasetconfig import DatasetConfig, convert_dataset_to_graph from fides.api.models.fides_user import FidesUser from fides.api.models.fides_user_permissions import FidesUserPermissions from fides.api.models.messaging import MessagingConfig @@ -61,6 +62,7 @@ PrivacyRequest, PrivacyRequestStatus, ProvidedIdentity, + RequestTask, ) from fides.api.models.registration import UserRegistration from fides.api.models.sql_models import DataCategory as DataCategoryDbModel @@ -98,6 +100,7 @@ from fides.api.util.data_category import DataCategory from fides.config import CONFIG from fides.config.helpers import load_file +from tests.ops.test_helpers.cache_secrets_helper import clear_cache_identities logging.getLogger("faker").setLevel(logging.ERROR) # disable verbose faker logging @@ -174,6 +177,7 @@ def mock_upload_logic() -> Generator: with mock.patch( "fides.api.service.storage.storage_uploader_service.upload_to_s3" ) as _fixture: + _fixture.return_value = "http://www.data-download-url" yield _fixture @@ -1320,6 +1324,175 @@ def privacy_request(db: Session, policy: Policy) -> PrivacyRequest: privacy_request.delete(db) +@pytest.fixture(scope="function") +def request_task(db: Session, privacy_request) -> RequestTask: + root_task = RequestTask.create( + db, + data={ + "action_type": ActionType.access, + "status": "complete", + "privacy_request_id": privacy_request.id, + "collection_address": "__ROOT__:__ROOT__", + "dataset_name": "__ROOT__", + "collection_name": "__ROOT__", + "upstream_tasks": [], + "downstream_tasks": ["test_dataset:test_collection"], + "all_descendant_tasks": [ + "test_dataset:test_collection", + "__TERMINATE__:__TERMINATE__", + ], + }, + ) + request_task = RequestTask.create( + db, + data={ + "action_type": ActionType.access, + "status": "pending", + "privacy_request_id": privacy_request.id, + "collection_address": "test_dataset:test_collection", + "dataset_name": "test_dataset", + "collection_name": "test_collection", + "upstream_tasks": ["__ROOT__:__ROOT__"], + "downstream_tasks": ["__TERMINATE__:__TERMINATE__"], + "all_descendant_tasks": ["__TERMINATE__:__TERMINATE__"], + }, + ) + end_task = RequestTask.create( + db, + data={ + "action_type": ActionType.access, + "status": "pending", + "privacy_request_id": privacy_request.id, + "collection_address": "__TERMINATE__:__TERMINATE__", + 
"dataset_name": "__TERMINATE__", + "collection_name": "__TERMINATE__", + "upstream_tasks": ["test_dataset:test_collection"], + "downstream_tasks": [], + "all_descendant_tasks": [], + }, + ) + yield request_task + + try: + end_task.delete(db).delete(db) + except ObjectDeletedError: + pass + try: + request_task.delete(db) + except ObjectDeletedError: + pass + try: + root_task.delete(db) + except ObjectDeletedError: + pass + + +@pytest.fixture(scope="function") +def erasure_request_task(db: Session, privacy_request) -> RequestTask: + root_task = RequestTask.create( + db, + data={ + "action_type": ActionType.erasure, + "status": "complete", + "privacy_request_id": privacy_request.id, + "collection_address": "__ROOT__:__ROOT__", + "dataset_name": "__ROOT__", + "collection_name": "__ROOT__", + "upstream_tasks": [], + "downstream_tasks": ["test_dataset:test_collection"], + "all_descendant_tasks": [ + "test_dataset:test_collection", + "__TERMINATE__:__TERMINATE__", + ], + }, + ) + request_task = RequestTask.create( + db, + data={ + "action_type": ActionType.erasure, + "status": "pending", + "privacy_request_id": privacy_request.id, + "collection_address": "test_dataset:test_collection", + "dataset_name": "test_dataset", + "collection_name": "test_collection", + "upstream_tasks": ["__ROOT__:__ROOT__"], + "downstream_tasks": ["__TERMINATE__:__TERMINATE__"], + "all_descendant_tasks": ["__TERMINATE__:__TERMINATE__"], + }, + ) + end_task = RequestTask.create( + db, + data={ + "action_type": ActionType.erasure, + "status": "pending", + "privacy_request_id": privacy_request.id, + "collection_address": "__TERMINATE__:__TERMINATE__", + "dataset_name": "__TERMINATE__", + "collection_name": "__TERMINATE__", + "upstream_tasks": ["test_dataset:test_collection"], + "downstream_tasks": [], + "all_descendant_tasks": [], + }, + ) + yield request_task + end_task.delete(db) + request_task.delete(db) + root_task.delete(db) + + +@pytest.fixture(scope="function") +def consent_request_task(db: Session, privacy_request) -> RequestTask: + root_task = RequestTask.create( + db, + data={ + "action_type": ActionType.consent, + "status": "complete", + "privacy_request_id": privacy_request.id, + "collection_address": "__ROOT__:__ROOT__", + "dataset_name": "__ROOT__", + "collection_name": "__ROOT__", + "upstream_tasks": [], + "downstream_tasks": ["test_dataset:test_collection"], + "all_descendant_tasks": [ + "test_dataset:test_collection", + "__TERMINATE__:__TERMINATE__", + ], + }, + ) + request_task = RequestTask.create( + db, + data={ + "action_type": ActionType.consent, + "status": "pending", + "privacy_request_id": privacy_request.id, + "collection_address": "test_dataset:test_collection", + "dataset_name": "test_dataset", + "collection_name": "test_collection", + "upstream_tasks": ["__ROOT__:__ROOT__"], + "downstream_tasks": ["__TERMINATE__:__TERMINATE__"], + "all_descendant_tasks": ["__TERMINATE__:__TERMINATE__"], + }, + ) + end_task = RequestTask.create( + db, + data={ + "action_type": ActionType.consent, + "status": "pending", + "privacy_request_id": privacy_request.id, + "collection_address": "__TERMINATE__:__TERMINATE__", + "dataset_name": "__TERMINATE__", + "collection_name": "__TERMINATE__", + "upstream_tasks": ["test_dataset:test_collection"], + "downstream_tasks": [], + "all_descendant_tasks": [], + }, + ) + yield request_task + end_task.delete(db) + request_task.delete(db) + root_task.delete(db) + + @pytest.fixture(scope="function") def privacy_request_with_erasure_policy( db: Session, erasure_policy: Policy @@ 
-3036,3 +3209,39 @@ served_notice_history( ) yield pref_1 pref_1.delete(db) + + +@pytest.fixture(scope="function") +def use_dsr_3_0(): + original_value: bool = CONFIG.execution.use_dsr_3_0 + CONFIG.execution.use_dsr_3_0 = True + yield CONFIG + CONFIG.execution.use_dsr_3_0 = original_value + + +@pytest.fixture(scope="function") +def use_dsr_2_0(): + original_value: bool = CONFIG.execution.use_dsr_3_0 + CONFIG.execution.use_dsr_3_0 = False + yield CONFIG + CONFIG.execution.use_dsr_3_0 = original_value + + +@pytest.fixture() +def postgres_dataset_graph(example_datasets, connection_config): + dataset_postgres = Dataset(**example_datasets[0]) + graph = convert_dataset_to_graph(dataset_postgres, connection_config.key) + + dataset_graph = DatasetGraph(*[graph]) + return dataset_graph + + +@pytest.fixture() +def postgres_and_mongo_dataset_graph( + example_datasets, connection_config, mongo_connection_config +): + dataset_postgres = Dataset(**example_datasets[0]) + graph = convert_dataset_to_graph(dataset_postgres, connection_config.key) + dataset_mongo = Dataset(**example_datasets[1]) + mongo_graph = convert_dataset_to_graph(dataset_mongo, mongo_connection_config.key) + return DatasetGraph(*[graph, mongo_graph]) diff --git a/tests/fixtures/integration_fixtures.py b/tests/fixtures/integration_fixtures.py index 2142eb4576..3807f1bc74 100644 --- a/tests/fixtures/integration_fixtures.py +++ b/tests/fixtures/integration_fixtures.py @@ -104,14 +104,19 @@ def generate_integration_records(): @pytest.fixture(scope="function") -def integration_postgres_config(postgres_inserts) -> ConnectionConfig: - return ConnectionConfig( - name="postgres_test", - key="postgres_example", - connection_type=ConnectionType.postgres, - access=AccessLevel.write, - secrets=integration_secrets["postgres_example"], +def integration_postgres_config(postgres_inserts, db) -> ConnectionConfig: + connection_config = ConnectionConfig.create( + db=db, + data={ + "key": "postgres_example", + "name": "postgres_test", + "connection_type": ConnectionType.postgres, + "access": AccessLevel.write, + "secrets": integration_secrets["postgres_example"], + }, ) + yield connection_config + connection_config.delete(db) def sql_insert(engine: Engine, table_name: str, record: Dict[str, Any]) -> None: diff --git a/tests/fixtures/mongodb_fixtures.py b/tests/fixtures/mongodb_fixtures.py index 976bbee7d2..be35d09479 100644 --- a/tests/fixtures/mongodb_fixtures.py +++ b/tests/fixtures/mongodb_fixtures.py @@ -1,4 +1,4 @@ -from typing import Generator +from typing import Dict, Generator, List from uuid import uuid4 import pytest @@ -9,12 +9,14 @@ ConnectionConfig, ConnectionType, ) +from fides.api.models.datasetconfig import DatasetConfig from fides.api.models.policy import ActionType from fides.api.models.privacy_request import ( ExecutionLog, ExecutionLogStatus, PrivacyRequest, ) +from fides.api.models.sql_models import Dataset as CtlDataset from .application_fixtures import integration_secrets @@ -40,6 +42,33 @@ def mongo_connection_config(db: Session) -> Generator: connection_config.delete(db) + +@pytest.fixture +def mongo_dataset_config( + mongo_connection_config: ConnectionConfig, + db: Session, + example_datasets: List[Dict], +) -> Generator: + mongo_dataset = example_datasets[1] + fides_key = mongo_dataset["fides_key"] + mongo_connection_config.name = fides_key + mongo_connection_config.key = fides_key + mongo_connection_config.save(db=db) + + ctl_dataset = CtlDataset.create_from_dataset_dict(db, mongo_dataset) + + dataset = DatasetConfig.create(
db=db, + data={ + "connection_config_id": mongo_connection_config.id, + "fides_key": fides_key, + "ctl_dataset_id": ctl_dataset.id, + }, + ) + yield dataset + dataset.delete(db=db) + ctl_dataset.delete(db=db) + + @pytest.fixture(scope="function") def mongo_execution_log( db: Session, diff --git a/tests/fixtures/saas/sentry_fixtures.py b/tests/fixtures/saas/sentry_fixtures.py index 247306c804..5a098a1d02 100644 --- a/tests/fixtures/saas/sentry_fixtures.py +++ b/tests/fixtures/saas/sentry_fixtures.py @@ -48,7 +48,7 @@ def sentry_identity_email(saas_config): @pytest.fixture def sentry_config() -> Dict[str, Any]: return load_config_with_replacement( - "data/saas/config/sentry_config.yml", "", "sentry_instance" + "data/saas/config/sentry_config.yml", "", "sentry_dataset" ) @@ -77,6 +77,24 @@ def sentry_connection_config(db: session, sentry_config, sentry_secrets) -> Gene connection_config.delete(db) + +@pytest.fixture(scope="function") +def sentry_connection_config_without_secrets(db: session, sentry_config) -> Generator: + fides_key = sentry_config["fides_key"] + connection_config = ConnectionConfig.create( + db=db, + data={ + "key": fides_key, + "name": fides_key, + "connection_type": ConnectionType.saas, + "access": AccessLevel.write, + "secrets": {}, + "saas_config": sentry_config, + }, + ) + yield connection_config + connection_config.delete(db) + + @pytest.fixture def sentry_dataset_config( db: Session, @@ -101,3 +119,29 @@ def sentry_dataset_config( yield dataset dataset.delete(db=db) ctl_dataset.delete(db=db) + + +@pytest.fixture +def sentry_dataset_config_without_secrets( + db: Session, + sentry_connection_config_without_secrets: ConnectionConfig, + sentry_dataset: Dict[str, Any], +) -> Generator: + fides_key = sentry_dataset["fides_key"] + sentry_connection_config_without_secrets.name = fides_key + sentry_connection_config_without_secrets.key = fides_key + sentry_connection_config_without_secrets.save(db=db) + + ctl_dataset = CtlDataset.create_from_dataset_dict(db, sentry_dataset) + + dataset = DatasetConfig.create( + db=db, + data={ + "connection_config_id": sentry_connection_config_without_secrets.id, + "fides_key": fides_key, + "ctl_dataset_id": ctl_dataset.id, + }, + ) + yield dataset + dataset.delete(db=db) + ctl_dataset.delete(db=db) diff --git a/tests/fixtures/saas/shopify_fixtures.py b/tests/fixtures/saas/shopify_fixtures.py index 030b9f8a0f..835685b133 100644 --- a/tests/fixtures/saas/shopify_fixtures.py +++ b/tests/fixtures/saas/shopify_fixtures.py @@ -148,7 +148,8 @@ def shopify_erasure_data( url=f"{base_url}/admin/api/2022-07/customers.json", json=body, headers=headers ) customer = customers_response.json() - assert customers_response.ok + # Not asserting that the customer response is okay: running back-to-back requests for DSR 2.0 + # and DSR 3.0 can cause a 422 because the email is already taken sleep(30) diff --git a/tests/fixtures/saas/test_data/planet_express/planet_express_functions.py b/tests/fixtures/saas/test_data/planet_express/planet_express_functions.py index 0ae242216e..d15de08a1a 100644 --- a/tests/fixtures/saas/test_data/planet_express/planet_express_functions.py +++ b/tests/fixtures/saas/test_data/planet_express/planet_express_functions.py @@ -2,7 +2,7 @@ from requests import PreparedRequest -from fides.api.graph.traversal import TraversalNode +from fides.api.graph.execution import ExecutionNode from fides.api.models.connectionconfig import ConnectionConfig from fides.api.models.policy import Policy from fides.api.models.privacy_request import
PrivacyRequest @@ -22,7 +22,7 @@ @register("planet_express_user_access", [SaaSRequestType.READ]) def planet_express_user_access( client: AuthenticatedClient, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, input_data: Dict[str, List[Any]], diff --git a/tests/ops/api/v1/endpoints/test_drp_endpoints.py b/tests/ops/api/v1/endpoints/test_drp_endpoints.py index d6544ffcf9..b47fe1486f 100644 --- a/tests/ops/api/v1/endpoints/test_drp_endpoints.py +++ b/tests/ops/api/v1/endpoints/test_drp_endpoints.py @@ -16,7 +16,11 @@ PrivacyRequestStatus, ) from fides.api.schemas.privacy_request import PrivacyRequestDRPStatus -from fides.api.util.cache import get_drp_request_body_cache_key, get_identity_cache_key +from fides.api.util.cache import ( + cache_task_tracking_key, + get_drp_request_body_cache_key, + get_identity_cache_key, +) from fides.common.api.scope_registry import ( POLICY_READ, PRIVACY_REQUEST_READ, @@ -558,8 +562,15 @@ def test_revoke_wrong_status( assert privacy_request.status == PrivacyRequestStatus.in_processing assert privacy_request.canceled_at is None + @mock.patch("fides.api.models.privacy_request.celery_app.control.revoke") def test_revoke( - self, db, api_client: TestClient, generate_auth_header, url, privacy_request + self, + revoke_task_mock, + db, + api_client: TestClient, + generate_auth_header, + url, + privacy_request, ): privacy_request.status = PrivacyRequestStatus.pending privacy_request.save(db) @@ -582,3 +593,55 @@ def test_revoke( assert data["request_id"] == privacy_request.id assert data["status"] == "revoked" assert data["reason"] == canceled_reason + + assert ( + not revoke_task_mock.called + ), "No celery task cached, so we don't attempt to revoke" + + @mock.patch("fides.api.models.privacy_request.celery_app.control.revoke") + def test_revoke_with_request_tasks( + self, + revoke_task_mock, + db, + api_client: TestClient, + generate_auth_header, + url, + privacy_request, + request_task, + ): + """Generally you can only revoke pending Privacy Requests, but the model-level + logic has the beginnings of support for revoking celery tasks""" + + privacy_request.status = PrivacyRequestStatus.pending + privacy_request.save(db) + canceled_reason = "Accidentally submitted" + + cache_task_tracking_key( + privacy_request.id, "mock_celery_task_id_for_privacy_request" + ) + cache_task_tracking_key(request_task.id, "mock_celery_task_id_for_request_task") + + auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_REVIEW]) + response = api_client.post( + url, + headers=auth_header, + json={"request_id": privacy_request.id, "reason": canceled_reason}, + ) + assert 200 == response.status_code + db.refresh(privacy_request) + + assert privacy_request.status == PrivacyRequestStatus.canceled + + assert revoke_task_mock.called + assert revoke_task_mock.call_count == 2 + + # Revokes the celery tasks for both the privacy request and the request task + assert { + revoke_task_mock.call_args_list[0][0][0], + revoke_task_mock.call_args_list[1][0][0], + } == { + "mock_celery_task_id_for_request_task", + "mock_celery_task_id_for_privacy_request", + } + + assert revoke_task_mock.call_args_list[0][1] == {"terminate": False} diff --git a/tests/ops/api/v1/endpoints/test_pre_approval_webhook_endpoints.py b/tests/ops/api/v1/endpoints/test_pre_approval_webhook_endpoints.py index 2a5aa918e2..f4be6b2483 100644 --- a/tests/ops/api/v1/endpoints/test_pre_approval_webhook_endpoints.py +++ b/tests/ops/api/v1/endpoints/test_pre_approval_webhook_endpoints.py @@ -370,7 +370,7
@@ def test_patch_pre_approval_webhook_invalid_webhook_key( assert resp.status_code == 404 def test_patch_pre_approval_webhook_nonexistent_connection_config_key( - self, api_client, url, generate_auth_header + self, api_client, url, generate_auth_header ): request_body = {"connection_config_key": "nonexistent_key"} auth_header = generate_auth_header(scopes=[WEBHOOK_CREATE_OR_UPDATE]) diff --git a/tests/ops/api/v1/endpoints/test_privacy_request_endpoints.py b/tests/ops/api/v1/endpoints/test_privacy_request_endpoints.py index 09d724857a..e9fa2c2286 100644 --- a/tests/ops/api/v1/endpoints/test_privacy_request_endpoints.py +++ b/tests/ops/api/v1/endpoints/test_privacy_request_endpoints.py @@ -22,7 +22,6 @@ JWE_ISSUED_AT, JWE_PAYLOAD_CLIENT_ID, JWE_PAYLOAD_ROLES, - JWE_PAYLOAD_SCOPES, ) from fides.api.graph.config import CollectionAddress from fides.api.graph.graph import DatasetGraph @@ -35,7 +34,6 @@ from fides.api.models.privacy_request import ( ExecutionLog, ExecutionLogStatus, - ManualAction, PrivacyRequest, PrivacyRequestError, PrivacyRequestNotifications, @@ -54,13 +52,9 @@ ) from fides.api.schemas.policy import ActionType, PolicyResponse from fides.api.schemas.redis_cache import Identity, LabeledIdentity -from fides.api.task import graph_task +from fides.api.task.graph_runners import access_runner from fides.api.tasks import MESSAGING_QUEUE_NAME -from fides.api.util.cache import ( - get_encryption_cache_key, - get_identity_cache_key, - get_masking_secret_cache_key, -) +from fides.api.util.cache import get_encryption_cache_key, get_masking_secret_cache_key from fides.common.api.scope_registry import ( DATASET_CREATE_OR_UPDATE, PRIVACY_REQUEST_CALLBACK_RESUME, @@ -80,11 +74,10 @@ PRIVACY_REQUEST_AUTHENTICATED, PRIVACY_REQUEST_BULK_RETRY, PRIVACY_REQUEST_DENY, - PRIVACY_REQUEST_MANUAL_ERASURE, - PRIVACY_REQUEST_MANUAL_INPUT, PRIVACY_REQUEST_MANUAL_WEBHOOK_ACCESS_INPUT, PRIVACY_REQUEST_MANUAL_WEBHOOK_ERASURE_INPUT, PRIVACY_REQUEST_NOTIFICATIONS, + PRIVACY_REQUEST_REQUEUE, PRIVACY_REQUEST_RESUME, PRIVACY_REQUEST_RESUME_FROM_REQUIRES_INPUT, PRIVACY_REQUEST_RETRY, @@ -92,6 +85,7 @@ PRIVACY_REQUEST_VERIFY_IDENTITY, PRIVACY_REQUESTS, REQUEST_PREVIEW, + REQUEST_TASKS, V1_URL_PREFIX, ) from fides.config import CONFIG @@ -387,7 +381,7 @@ def test_create_privacy_request_with_masking_configuration( assert run_access_request_mock.called @mock.patch( - "fides.api.service.privacy_request.request_runner_service.run_access_request" + "fides.api.service.privacy_request.request_runner_service.access_runner" ) def test_create_privacy_request_limit_exceeded( self, @@ -1536,47 +1530,6 @@ def test_get_privacy_requests_csv_format( privacy_request.delete(db) - def test_get_paused_access_privacy_request_resume_info( - self, db, privacy_request, generate_auth_header, api_client, url - ): - # Mock the privacy request being in a paused state waiting for manual input to the "manual_collection" - privacy_request.status = PrivacyRequestStatus.paused - privacy_request.save(db) - paused_step = CurrentStep.access - paused_collection = CollectionAddress("manual_dataset", "manual_collection") - privacy_request.cache_paused_collection_details( - step=paused_step, - collection=paused_collection, - action_needed=[ - ManualAction( - locators={"email": ["customer-1@example.com"]}, - get=["authorized_user"], - update=None, - ) - ], - ) - - auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_READ]) - response = api_client.get(url, headers=auth_header) - assert 200 == response.status_code - - data = 
response.json()["items"][0] - assert data["status"] == "paused" - assert data["action_required_details"] == { - "step": "access", - "collection": "manual_dataset:manual_collection", - "action_needed": [ - { - "locators": {"email": ["customer-1@example.com"]}, - "get": ["authorized_user"], - "update": None, - } - ], - } - assert data["resume_endpoint"] == "/privacy-request/{}/manual_input".format( - privacy_request.id - ) - def test_get_requires_input_privacy_request_resume_info( self, db, privacy_request, generate_auth_header, api_client, url ): @@ -1595,47 +1548,6 @@ def test_get_requires_input_privacy_request_resume_info( "resume_endpoint" ] == "/privacy-request/{}/resume_from_requires_input".format(privacy_request.id) - def test_get_paused_erasure_privacy_request_resume_info( - self, db, privacy_request, generate_auth_header, api_client, url - ): - # Mock the privacy request being in a paused state waiting for manual erasure confirmation to the "another_collection" - privacy_request.status = PrivacyRequestStatus.paused - privacy_request.save(db) - paused_step = CurrentStep.erasure - paused_collection = CollectionAddress("manual_dataset", "another_collection") - privacy_request.cache_paused_collection_details( - step=paused_step, - collection=paused_collection, - action_needed=[ - ManualAction( - locators={"id": [32424]}, - get=None, - update={"authorized_user": "abcde_masked_user"}, - ) - ], - ) - - auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_READ]) - response = api_client.get(url, headers=auth_header) - assert 200 == response.status_code - - data = response.json()["items"][0] - assert data["status"] == "paused" - assert data["action_required_details"] == { - "step": "erasure", - "collection": "manual_dataset:another_collection", - "action_needed": [ - { - "locators": {"id": [32424]}, - "get": None, - "update": {"authorized_user": "abcde_masked_user"}, - } - ], - } - assert data["resume_endpoint"] == "/privacy-request/{}/erasure_confirm".format( - privacy_request.id - ) - def test_get_paused_webhook_resume_info( self, db, privacy_request, generate_auth_header, api_client, url ): @@ -1653,7 +1565,7 @@ def test_get_paused_webhook_resume_info( privacy_request.id ) - def test_get_failed_request_resume_info_from_collection( + def test_get_failed_request_resume_info( self, db, privacy_request, generate_auth_header, api_client, url ): # Mock the privacy request being in an errored state waiting for retry @@ -1661,7 +1573,6 @@ def test_get_failed_request_resume_info_from_collection( privacy_request.save(db) privacy_request.cache_failed_checkpoint_details( step=CurrentStep.erasure, - collection=CollectionAddress("manual_example", "another_collection"), ) auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_READ]) @@ -1672,12 +1583,12 @@ def test_get_failed_request_resume_info_from_collection( assert data["status"] == "error" assert data["action_required_details"] == { "step": "erasure", - "collection": "manual_example:another_collection", + "collection": None, "action_needed": None, } assert data["resume_endpoint"] == f"/privacy-request/{privacy_request.id}/retry" - def test_get_failed_request_resume_info_from_email_send( + def test_get_failed_request_resume_info_from_email_post_send( self, db, privacy_request, generate_auth_header, api_client, url ): # Mock the privacy request being in an errored state waiting for retry @@ -1685,7 +1596,6 @@ def test_get_failed_request_resume_info_from_email_send( privacy_request.save(db) privacy_request.cache_failed_checkpoint_details( 
step=CurrentStep.email_post_send, - collection=None, ) auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_READ]) @@ -1987,9 +1897,9 @@ def test_request_preview_incorrect_body( def test_request_preview_all( self, dataset_config_preview, - manual_dataset_config, - integration_manual_config, + mongo_connection_config, postgres_example_test_dataset_config, + mongo_dataset_config, api_client: TestClient, url, generate_auth_header, @@ -2008,24 +1918,15 @@ def test_request_preview_all( ) == "SELECT email,id FROM subscriptions WHERE email = ?" ) - - assert next( - response["query"] - for response in response_body - if response["collectionAddress"]["dataset"] == "manual_input" - if response["collectionAddress"]["collection"] == "filing_cabinet" - ) == { - "locators": {"customer_id": ["?", "?"]}, - "get": ["authorized_user", "customer_id", "id", "payment_card_id"], - "update": None, - } - - assert next( - response["query"] - for response in response_body - if response["collectionAddress"]["dataset"] == "manual_input" - if response["collectionAddress"]["collection"] == "storage_unit" - ) == {"locators": {"email": ["?"]}, "get": ["box_id", "email"], "update": None} + assert ( + next( + response["query"] + for response in response_body + if response["collectionAddress"]["dataset"] == "mongo_test" + if response["collectionAddress"]["collection"] == "customer_feedback" + ) + == "db.mongo_test.customer_feedback.find({'customer_information.email': ?}, {'_id': 1, 'customer_information': 1, 'date': 1, 'message': 1, 'rating': 1})" + ) class TestApprovePrivacyRequest: @@ -2710,146 +2611,6 @@ def test_resume_privacy_request( privacy_request.delete(db) -class TestResumeAccessRequestWithManualInput: - @pytest.fixture(scope="function") - def url(self, privacy_request): - return V1_URL_PREFIX + PRIVACY_REQUEST_MANUAL_INPUT.format( - privacy_request_id=privacy_request.id - ) - - def test_manual_resume_not_authenticated(self, api_client, url): - response = api_client.post(url, headers={}, json={}) - assert response.status_code == 401 - - def test_manual_resume_wrong_scope(self, api_client, url, generate_auth_header): - auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_READ]) - - response = api_client.post(url, headers=auth_header, json={}) - assert response.status_code == 403 - - def test_manual_resume_privacy_request_not_paused( - self, api_client, url, generate_auth_header, privacy_request - ): - auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) - response = api_client.post(url, headers=auth_header, json=[{"mock": "row"}]) - assert response.status_code == 400 - assert ( - response.json()["detail"] - == f"Invalid resume request: privacy request '{privacy_request.id}' status = in_processing. Privacy request is not paused." - ) - - def test_manual_resume_privacy_request_no_paused_location( - self, db, api_client, url, generate_auth_header, privacy_request - ): - auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) - privacy_request.status = PrivacyRequestStatus.paused - privacy_request.save(db) - - response = api_client.post(url, headers=auth_header, json=[{"mock": "row"}]) - assert response.status_code == 400 - assert ( - response.json()["detail"] - == f"Cannot resume privacy request '{privacy_request.id}'; no paused details." 
- ) - - privacy_request.delete(db) - - def test_resume_with_manual_input_collection_has_changed( - self, db, api_client, url, generate_auth_header, privacy_request - ): - """Fail if user has changed graph so that the paused node doesn't exist""" - auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) - privacy_request.status = PrivacyRequestStatus.paused - privacy_request.save(db) - - privacy_request.cache_paused_collection_details( - step=CurrentStep.access, - collection=CollectionAddress("manual_example", "filing_cabinet"), - ) - - response = api_client.post(url, headers=auth_header, json=[{"mock": "row"}]) - assert response.status_code == 422 - assert ( - response.json()["detail"] - == "Cannot save manual data. No collection in graph with name: 'manual_example:filing_cabinet'." - ) - - privacy_request.delete(db) - - @pytest.mark.usefixtures( - "postgres_example_test_dataset_config", "manual_dataset_config" - ) - def test_resume_with_manual_input_invalid_data( - self, - db, - api_client, - url, - generate_auth_header, - privacy_request, - ): - """Fail if the manual data entered does not match fields on the dataset""" - auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) - privacy_request.status = PrivacyRequestStatus.paused - privacy_request.save(db) - - privacy_request.cache_paused_collection_details( - step=CurrentStep.access, - collection=CollectionAddress("manual_input", "filing_cabinet"), - ) - - response = api_client.post(url, headers=auth_header, json=[{"mock": "row"}]) - assert response.status_code == 422 - assert ( - response.json()["detail"] - == "Cannot save manual rows. No 'mock' field defined on the 'manual_input:filing_cabinet' collection." - ) - - privacy_request.delete(db) - - @mock.patch( - "fides.api.service.privacy_request.request_runner_service.run_privacy_request.delay" - ) - @pytest.mark.usefixtures( - "postgres_example_test_dataset_config", "manual_dataset_config" - ) - def test_resume_with_manual_input( - self, - _, - db, - api_client, - url, - generate_auth_header, - privacy_request, - ): - auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) - privacy_request.status = PrivacyRequestStatus.paused - privacy_request.save(db) - - privacy_request.cache_paused_collection_details( - step=CurrentStep.access, - collection=CollectionAddress("manual_input", "filing_cabinet"), - ) - - response = api_client.post( - url, - headers=auth_header, - json=[ - { - "id": 1, - "authorized_user": "Jason Doe", - "customer_id": 1, - "payment_card_id": "abcde", - } - ], - ) - assert response.status_code == 200 - - db.refresh(privacy_request) - assert privacy_request.status == PrivacyRequestStatus.in_processing - - privacy_request.delete(db) - - class TestValidateManualInput: """Verify pytest cell-var-from-loop warning is a false positive""" @@ -2896,130 +2657,6 @@ def test_field_on_second_row_does_not_match(self, dataset_graph): ) -class TestResumeErasureRequestWithManualConfirmation: - @pytest.fixture(scope="function") - def url(self, privacy_request): - return V1_URL_PREFIX + PRIVACY_REQUEST_MANUAL_ERASURE.format( - privacy_request_id=privacy_request.id - ) - - def test_manual_resume_not_authenticated(self, api_client, url): - response = api_client.post(url, headers={}, json={}) - assert response.status_code == 401 - - def test_manual_resume_wrong_scope(self, api_client, url, generate_auth_header): - auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_READ]) - - response = api_client.post(url, 
headers=auth_header, json={}) - assert response.status_code == 403 - - def test_manual_resume_privacy_request_not_paused( - self, api_client, url, generate_auth_header, privacy_request - ): - auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) - response = api_client.post(url, headers=auth_header, json={"row_count": 0}) - assert response.status_code == 400 - assert ( - response.json()["detail"] - == f"Invalid resume request: privacy request '{privacy_request.id}' status = in_processing. Privacy request is not paused." - ) - - def test_manual_resume_privacy_request_no_paused_location( - self, db, api_client, url, generate_auth_header, privacy_request - ): - auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) - privacy_request.status = PrivacyRequestStatus.paused - privacy_request.save(db) - - response = api_client.post(url, headers=auth_header, json={"row_count": 0}) - assert response.status_code == 400 - assert ( - response.json()["detail"] - == f"Cannot resume privacy request '{privacy_request.id}'; no paused details." - ) - - privacy_request.delete(db) - - def test_resume_with_manual_erasure_confirmation_collection_has_changed( - self, db, api_client, url, generate_auth_header, privacy_request - ): - """Fail if user has changed graph so that the paused node doesn't exist""" - auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) - privacy_request.status = PrivacyRequestStatus.paused - privacy_request.save(db) - - privacy_request.cache_paused_collection_details( - step=CurrentStep.erasure, - collection=CollectionAddress("manual_example", "filing_cabinet"), - ) - - response = api_client.post(url, headers=auth_header, json={"row_count": 0}) - assert response.status_code == 422 - assert ( - response.json()["detail"] - == "Cannot save manual data. No collection in graph with name: 'manual_example:filing_cabinet'." - ) - - privacy_request.delete(db) - - def test_resume_still_paused_at_access_request( - self, db, api_client, url, generate_auth_header, privacy_request - ): - """Fail if user hitting wrong endpoint to resume.""" - auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) - privacy_request.status = PrivacyRequestStatus.paused - privacy_request.save(db) - - privacy_request.cache_paused_collection_details( - step=CurrentStep.access, - collection=CollectionAddress("manual_example", "filing_cabinet"), - ) - response = api_client.post(url, headers=auth_header, json={"row_count": 0}) - assert response.status_code == 400 - - assert ( - response.json()["detail"] - == "Collection 'manual_example:filing_cabinet' is paused at the access step. Pass in manual data instead to '/privacy-request/{privacy_request_id}/manual_input' to resume." 
- ) - - privacy_request.delete(db) - - @pytest.mark.usefixtures( - "postgres_example_test_dataset_config", "manual_dataset_config" - ) - @mock.patch( - "fides.api.service.privacy_request.request_runner_service.run_privacy_request.delay" - ) - def test_resume_with_manual_count( - self, - _, - db, - api_client, - url, - generate_auth_header, - privacy_request, - ): - auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) - privacy_request.status = PrivacyRequestStatus.paused - privacy_request.save(db) - - privacy_request.cache_paused_collection_details( - step=CurrentStep.erasure, - collection=CollectionAddress("manual_input", "filing_cabinet"), - ) - response = api_client.post( - url, - headers=auth_header, - json={"row_count": 5}, - ) - assert response.status_code == 200 - - db.refresh(privacy_request) - assert privacy_request.status == PrivacyRequestStatus.in_processing - - privacy_request.delete(db) - - class TestBulkRestartFromFailure: @pytest.fixture(scope="function") def url(self): @@ -3092,6 +2729,8 @@ def test_restart_from_failure_no_stopped_step( def test_restart_from_failure_from_specific_collection( self, submit_mock, api_client, url, generate_auth_header, db, privacy_requests ): + """Collection is no longer a relevant parameter here, but its inclusion does not affect + restarting the privacy request""" auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) data = [privacy_requests[0].id] privacy_requests[0].status = PrivacyRequestStatus.error @@ -3099,7 +2738,6 @@ def test_restart_from_failure_from_specific_collection( privacy_requests[0].cache_failed_checkpoint_details( step=CurrentStep.access, - collection=CollectionAddress("test_dataset", "test_collection"), ) response = api_client.post(url, json=data, headers=auth_header) @@ -3132,7 +2770,6 @@ def test_restart_from_failure_outside_graph( privacy_requests[0].cache_failed_checkpoint_details( step=CurrentStep.email_post_send, - collection=None, ) response = api_client.post(url, json=data, headers=auth_header) @@ -3166,7 +2803,6 @@ def test_mixed_result( privacy_requests[0].cache_failed_checkpoint_details( step=CurrentStep.access, - collection=CollectionAddress("test_dataset", "test_collection"), ) response = api_client.post(url, json=data, headers=auth_header) @@ -3248,7 +2884,7 @@ def test_restart_from_failure_no_stopped_step( @mock.patch( "fides.api.service.privacy_request.request_runner_service.run_privacy_request.delay" ) - def test_restart_from_failure_from_specific_collection( + def test_restart_from_failure_from_access_step( self, submit_mock, api_client, url, generate_auth_header, db, privacy_request ): auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) @@ -3257,7 +2893,6 @@ def test_restart_from_failure_from_specific_collection( privacy_request.cache_failed_checkpoint_details( step=CurrentStep.access, - collection=CollectionAddress("test_dataset", "test_collection"), ) response = api_client.post(url, headers=auth_header) @@ -3275,7 +2910,7 @@ def test_restart_from_failure_from_specific_collection( @mock.patch( "fides.api.service.privacy_request.request_runner_service.run_privacy_request.delay" ) - def test_restart_from_failure_outside_graph( + def test_restart_from_email_post_send( self, submit_mock, api_client, url, generate_auth_header, db, privacy_request ): auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) @@ -3284,7 +2919,6 @@ def test_restart_from_failure_outside_graph( privacy_request.cache_failed_checkpoint_details( 
step=CurrentStep.email_post_send, - collection=None, ) response = api_client.post(url, headers=auth_header) @@ -4788,7 +4422,7 @@ def test_create_privacy_request_with_masking_configuration( assert len(response_data) == 1 @mock.patch( - "fides.api.service.privacy_request.request_runner_service.run_access_request" + "fides.api.service.privacy_request.request_runner_service.access_runner" ) def test_create_privacy_request_limit_exceeded( self, @@ -5086,7 +4720,7 @@ async def test_privacy_request_data_transfer( # execute the privacy request to mimic the expected workflow on the "child" # this will populate the access results in the cache, which is required for the # transfer endpoint to work - await graph_task.run_access_request( + access_runner( privacy_request, policy, graph, @@ -5301,3 +4935,223 @@ def test_update_privacy_request_notification_info( response = api_client.put(url, json=data, headers=auth_header) assert response.status_code == 200 assert response.json() == data + + +class TestPrivacyRequestTasksList: + @pytest.fixture(scope="function") + def url(self, privacy_request) -> str: + return V1_URL_PREFIX + REQUEST_TASKS.format( + privacy_request_id=privacy_request.id + ) + + def test_get_request_tasks_unauthenticated(self, api_client: TestClient, url): + response = api_client.get(url, headers={}) + assert 401 == response.status_code + + def test_get_request_tasks_wrong_scope( + self, api_client: TestClient, generate_auth_header, url + ): + auth_header = generate_auth_header(scopes=[STORAGE_CREATE_OR_UPDATE]) + response = api_client.get(url, headers=auth_header) + assert 403 == response.status_code + + def test_no_tasks(self, api_client, generate_auth_header, url): + auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_READ]) + response = api_client.get(url, headers=auth_header) + assert 200 == response.status_code + + assert response.json() == [] + + def test_get_request_tasks( + self, api_client: TestClient, generate_auth_header, url, request_task + ): + auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_READ]) + response = api_client.get(url, headers=auth_header) + assert 200 == response.status_code + assert len(response.json()) == 3 + resp = response.json() + root_response = resp[0] + assert root_response["collection_address"] == "__ROOT__:__ROOT__" + assert resp[1]["collection_address"] == "test_dataset:test_collection" + assert resp[2]["collection_address"] == "__TERMINATE__:__TERMINATE__" + + assert root_response["upstream_tasks"] == [] + assert root_response["downstream_tasks"] == ["test_dataset:test_collection"] + assert root_response["status"] == "complete" + assert root_response["action_type"] == "access" + + # No DSR data is returned in the response + assert set(root_response.keys()) == { + "id", + "collection_address", + "status", + "created_at", + "updated_at", + "upstream_tasks", + "downstream_tasks", + "action_type", + } + + +class TestRequeuePrivacyRequest: + @pytest.fixture(scope="function") + def url(self, privacy_request, request_task) -> str: + return V1_URL_PREFIX + PRIVACY_REQUEST_REQUEUE.format( + privacy_request_id=privacy_request.id + ) + + def test_requeue_privacy_request_unauthenticated(self, api_client: TestClient, url): + response = api_client.post(url, headers={}) + assert 401 == response.status_code + + def test_requeue_privacy_request_wrong_scope( + self, api_client: TestClient, generate_auth_header, url + ): + auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_READ]) + response = api_client.post(url, headers=auth_header) + 
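These test classes pin down the two routes added to the URN registry earlier in the diff: GET /privacy-request/{id}/tasks returns only task-graph metadata (no access_data or data_for_erasures), and POST /privacy-request/{id}/requeue restarts the request from its last incomplete step. A hedged sketch of calling them from an HTTP client; the host, token, and request id are placeholders:

```python
import requests

BASE = "http://localhost:8080/api/v1"          # placeholder host + version prefix
HEADERS = {"Authorization": "Bearer <token>"}  # placeholder auth token
pr_id = "pri_123"                              # placeholder privacy request id

# List the request tasks: only graph metadata is exposed, never DSR data.
tasks = requests.get(f"{BASE}/privacy-request/{pr_id}/tasks", headers=HEADERS)
for task in tasks.json():
    print(task["collection_address"], task["status"], task["action_type"])

# Re-queue a stuck request from the last incomplete step.
requests.post(f"{BASE}/privacy-request/{pr_id}/requeue", headers=HEADERS)
```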
assert 403 == response.status_code + + def test_requeue_privacy_request_privacy_request_not_found( + self, api_client: TestClient, generate_auth_header, url, request_task + ): + url = V1_URL_PREFIX + PRIVACY_REQUEST_REQUEUE.format( + privacy_request_id="adsf", task_id=request_task.id + ) + + auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) + response = api_client.post(url, headers=auth_header) + assert 404 == response.status_code + + def test_requeue_privacy_request_already_completed( + self, + db, + api_client: TestClient, + generate_auth_header, + url, + privacy_request, + request_task, + ): + privacy_request.status = PrivacyRequestStatus.complete + privacy_request.save(db) + auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) + response = api_client.post(url, headers=auth_header) + assert 400 == response.status_code + assert ( + response.json()["detail"] + == f"Request failed. Cannot re-queue privacy request {privacy_request.id} with status {privacy_request.status.value}" + ) + + @mock.patch( + "fides.api.api.v1.endpoints.privacy_request_endpoints.queue_privacy_request" + ) + def test_requeue_privacy_request_from_cached_failure_point( + self, + queue_privacy_request_mock, + privacy_request, + api_client: TestClient, + generate_auth_header, + url, + ): + privacy_request.cache_failed_checkpoint_details( + step=CurrentStep.erasure, + ) + + auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) + response = api_client.post(url, headers=auth_header) + assert 200 == response.status_code + assert queue_privacy_request_mock.called + queue_privacy_request_mock.assert_called_with( + privacy_request_id=privacy_request.id, + from_step=CurrentStep.erasure.value, + ) + + @pytest.mark.usefixtures("consent_request_task") + @mock.patch( + "fides.api.api.v1.endpoints.privacy_request_endpoints.queue_privacy_request" + ) + def test_requeue_privacy_request_with_consent_tasks( + self, + queue_privacy_request_mock, + privacy_request, + api_client: TestClient, + generate_auth_header, + url, + ): + auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) + response = api_client.post(url, headers=auth_header) + assert 200 == response.status_code + assert queue_privacy_request_mock.called + queue_privacy_request_mock.assert_called_with( + privacy_request_id=privacy_request.id, + from_step=CurrentStep.consent.value, + ) + + @pytest.mark.usefixtures("erasure_request_task", "request_task") + @mock.patch( + "fides.api.api.v1.endpoints.privacy_request_endpoints.queue_privacy_request" + ) + def test_requeue_privacy_request_with_erasure_tasks( + self, + queue_privacy_request_mock, + db, + privacy_request, + api_client: TestClient, + generate_auth_header, + url, + ): + terminate_access_task = privacy_request.get_terminate_task_by_action( + ActionType.access + ) + terminate_access_task.status = ExecutionLogStatus.complete + terminate_access_task.save(db) + + auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) + response = api_client.post(url, headers=auth_header) + assert 200 == response.status_code + assert queue_privacy_request_mock.called + queue_privacy_request_mock.assert_called_with( + privacy_request_id=privacy_request.id, + from_step=CurrentStep.erasure.value, + ) + + @pytest.mark.usefixtures("erasure_request_task", "request_task") + @mock.patch( + "fides.api.api.v1.endpoints.privacy_request_endpoints.queue_privacy_request" + ) + def 
test_requeue_privacy_request_erasure_tasks_but_access_step_not_complete( + self, + queue_privacy_request_mock, + privacy_request, + api_client: TestClient, + generate_auth_header, + url, + ): + auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) + response = api_client.post(url, headers=auth_header) + assert 200 == response.status_code + assert queue_privacy_request_mock.called + queue_privacy_request_mock.assert_called_with( + privacy_request_id=privacy_request.id, + from_step=CurrentStep.access.value, + ) + + @pytest.mark.usefixtures("request_task") + @mock.patch( + "fides.api.api.v1.endpoints.privacy_request_endpoints.queue_privacy_request" + ) + def test_requeue_privacy_request_with_access_tasks( + self, + queue_privacy_request_mock, + privacy_request, + api_client: TestClient, + generate_auth_header, + url, + ): + auth_header = generate_auth_header(scopes=[PRIVACY_REQUEST_CALLBACK_RESUME]) + response = api_client.post(url, headers=auth_header) + assert 200 == response.status_code + assert queue_privacy_request_mock.called + queue_privacy_request_mock.assert_called_with( + privacy_request_id=privacy_request.id, + from_step=CurrentStep.access.value, + ) diff --git a/tests/ops/graph/graph_test_util.py b/tests/ops/graph/graph_test_util.py index b61150ec0e..466a6959e0 100644 --- a/tests/ops/graph/graph_test_util.py +++ b/tests/ops/graph/graph_test_util.py @@ -1,18 +1,19 @@ import random +import uuid from typing import Iterable -from fideslang.validation import FidesKey from sqlalchemy.engine import Engine from fides.api.db.base_class import FidesBase from fides.api.graph.config import * +from fides.api.graph.execution import ExecutionNode from fides.api.graph.traversal import * from fides.api.graph.traversal import Traversal, TraversalNode # to avoid having faker spam the logs from fides.api.models.connectionconfig import ConnectionConfig from fides.api.models.policy import ActionType, Policy, Rule, RuleTarget -from fides.api.models.privacy_request import PrivacyRequest +from fides.api.models.privacy_request import PrivacyRequest, RequestTask from fides.api.service.connectors import BaseConnector, MongoDBConnector from fides.api.service.connectors.sql_connector import SQLConnector from fides.api.task.graph_task import GraphTask @@ -43,12 +44,13 @@ def client(self) -> Engine: def retrieve_data( self, - node: TraversalNode, + node: ExecutionNode, policy: Policy, privacy_request: PrivacyRequest, + request_task: RequestTask, input_data: Dict[str, List[Any]], ) -> List[Row]: - return [generate_collection(node.node.collection) for _ in range(3)] + return [generate_collection(node.collection) for _ in range(3)] class MockSqlTask(GraphTask): @@ -64,20 +66,30 @@ def connector(self) -> BaseConnector: # ------------------------------------------- # test utility functions # ------------------------------------------- -def erasure_policy(*erasure_categories: str) -> Policy: +def erasure_policy(db, *erasure_categories: str) -> Policy: """Generate an erasure policy with the given categories""" - policy = Policy() - targets = [RuleTarget(data_category=c) for c in erasure_categories] - policy.rules = [ - Rule( - action_type=ActionType.erasure, - targets=targets, - masking_strategy={ + policy = Policy.create( + db=db, + data={ + "name": str(uuid.uuid4()), + "key": str(uuid.uuid4()), + }, + ) + rule = Rule.create( + db, + data={ + "action_type": ActionType.erasure, + "name": str(uuid.uuid4()), + "masking_strategy": { "strategy": "null_rewrite", "configuration": {}, }, - ) - ] + 
"policy_id": policy.id, + }, + ) + for c in erasure_categories: + RuleTarget.create(db, data={"data_category": c, "rule_id": rule.id}) + return policy diff --git a/tests/ops/graph/test_config.py b/tests/ops/graph/test_config.py index f3aa2171f2..cd69790dad 100644 --- a/tests/ops/graph/test_config.py +++ b/tests/ops/graph/test_config.py @@ -1,3 +1,5 @@ +import json + import pydantic import pytest @@ -74,6 +76,123 @@ def test_field_address_collection_address(self): "A", "B", "C", "D", "E" ).collection_address() == CollectionAddress("A", "B") + def test_from_string(self): + assert FieldAddress.from_string("A:B:C") == FieldAddress("A", "B", "C") + + assert FieldAddress.from_string("A:B:C.D.E") == FieldAddress( + "A", "B", "C", "D", "E" + ) + + with pytest.raises(FidesopsException): + FieldAddress.from_string("A") + + with pytest.raises(FidesopsException): + FieldAddress.from_string("A:B") + + with pytest.raises(FidesopsException): + FieldAddress.from_string("A.B") + + +collection_to_serialize = ds = Collection( + name="t3", + skip_processing=False, + fields=[ + ScalarField( + name="f1", + identity="email", + data_type_converter=StringTypeConverter(), + data_categories=["user"], + return_all_elements=False, + references=[ + (FieldAddress("a", "b", "c"), "to"), + (FieldAddress("a", "b", "d"), "from"), + ], + ), + ScalarField( + name="f2", + data_type_converter=IntTypeConverter(), + references=[(FieldAddress("d", "e", "f"), None)], + ), + ScalarField(name="f3", is_array=True, read_only=False), + ObjectField(name="f4", fields={"f5": ScalarField(name="f5")}), + ], + after={CollectionAddress("i", "j")}, + erase_after={CollectionAddress("g", "h")}, + grouped_inputs={"test_param"}, +) + +serialized_collection = { + "name": "t3", + "skip_processing": False, + "fields": [ + { + "name": "f1", + "primary_key": False, + "references": [["a:b:c", "to"], ["a:b:d", "from"]], + "identity": "email", + "data_categories": ["user"], + "data_type_converter": "string", + "return_all_elements": False, + "length": None, + "is_array": False, + "read_only": None, + }, + { + "name": "f2", + "primary_key": False, + "references": [["d:e:f", None]], + "identity": None, + "data_categories": None, + "data_type_converter": "integer", + "return_all_elements": None, + "length": None, + "is_array": False, + "read_only": None, + }, + { + "name": "f3", + "primary_key": False, + "references": [], + "identity": None, + "data_categories": None, + "data_type_converter": "None", + "return_all_elements": None, + "length": None, + "is_array": True, + "read_only": False, + }, + { + "name": "f4", + "primary_key": False, + "references": [], + "identity": None, + "data_categories": None, + "data_type_converter": "None", + "return_all_elements": None, + "length": None, + "is_array": False, + "read_only": None, + "fields": { + "f5": { + "name": "f5", + "primary_key": False, + "references": [], + "identity": None, + "data_categories": None, + "data_type_converter": "None", + "return_all_elements": None, + "length": None, + "is_array": False, + "read_only": None, + } + }, + }, + ], + "after": ["i:j"], + "erase_after": ["g:h"], + "grouped_inputs": ["test_param"], +} + class TestCollection: def test_collection_field_dict(self): @@ -235,6 +354,14 @@ def test_field_paths_by_category(self): ], # Applies to a nested field } + def test_collection_json(self): + json_collection = json.loads(collection_to_serialize.json()) + assert json_collection == serialized_collection + + def test_parse_from_task(self): + parsed = 
Collection.parse_from_request_task(serialized_collection) + assert parsed == collection_to_serialize + class TestField: def test_generate_field(self) -> None: diff --git a/tests/ops/graph/test_data_types.py b/tests/ops/graph/test_data_types.py index cb88183f57..11af62f22f 100644 --- a/tests/ops/graph/test_data_types.py +++ b/tests/ops/graph/test_data_types.py @@ -59,6 +59,7 @@ def test_safe_none_conversion(): def test_get_data_type_converter(): + assert isinstance(get_data_type_converter("None"), NoOpTypeConverter) assert isinstance(get_data_type_converter(None), NoOpTypeConverter) assert isinstance(get_data_type_converter(""), NoOpTypeConverter) assert isinstance(get_data_type_converter("string"), StringTypeConverter) diff --git a/tests/ops/graph/test_graph.py b/tests/ops/graph/test_graph.py index ed9cca2b10..7878433b43 100644 --- a/tests/ops/graph/test_graph.py +++ b/tests/ops/graph/test_graph.py @@ -53,10 +53,10 @@ def test_node_eq(self) -> None: def test_node_contains_field(self) -> None: node = graph.nodes[CollectionAddress("s1", "t1")] - assert node.contains_field(lambda f: f.name == "f3") - assert node.contains_field(lambda f: f.name == "f6") is False - assert node.contains_field(lambda f: f.primary_key) - assert node.contains_field(lambda f: f.identity == "ssn") + assert node.collection.contains_field(lambda f: f.name == "f3") + assert node.collection.contains_field(lambda f: f.name == "f6") is False + assert node.collection.contains_field(lambda f: f.primary_key) + assert node.collection.contains_field(lambda f: f.identity == "ssn") def test_retry_decorator(privacy_request, policy, db): @@ -69,6 +69,7 @@ def test_retry_decorator(privacy_request, policy, db): payment_card_node = traversal_nodes[ CollectionAddress("postgres_example", "payment_card") ] + execution_node = payment_card_node.to_mock_execution_node() CONFIG.execution.task_retry_count = 5 CONFIG.execution.task_retry_delay = 0.1 @@ -76,12 +77,18 @@ def test_retry_decorator(privacy_request, policy, db): class TestRetryDecorator: def __init__(self): - self.traversal_node = payment_card_node + self.execution_node = execution_node self.call_count = 0 self.start_logged = 0 self.retry_logged = 0 self.end_called_with = () - self.resources = TaskResources(privacy_request, policy, [], db) + self.resources = TaskResources( + privacy_request, + policy, + [], + payment_card_node.to_mock_request_task(), + db, + ) def log_end(self, action_type: ActionType, exc: Optional[str] = None): self.end_called_with = (action_type, exc) diff --git a/tests/ops/graph/test_graph_analytics_events.py b/tests/ops/graph/test_graph_analytics_events.py deleted file mode 100644 index 37af7e26ce..0000000000 --- a/tests/ops/graph/test_graph_analytics_events.py +++ /dev/null @@ -1,20 +0,0 @@ -from fides.api.common_exceptions import FidesopsException -from fides.api.graph.analytics_events import failed_graph_analytics_event - - -class TestFailedGraphAnalyticsEvent: - def test_create_failed_privacy_request_event(self, privacy_request): - fake_exception = FidesopsException("Graph Failed") - analytics_event = failed_graph_analytics_event(privacy_request, fake_exception) - - assert analytics_event.docker is True - assert analytics_event.event == "privacy_request_execution_failure" - assert analytics_event.event_created_at is not None - assert analytics_event.extra_data == { - "privacy_request": privacy_request.id, - } - - assert analytics_event.error == "FidesopsException" - assert analytics_event.status_code == 500 - assert analytics_event.endpoint is None - assert 
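
# The two tests above pin down the round trip DSR 3.0 relies on when it
# rehydrates collections from RequestTask rows: Collection -> .json() ->
# dict -> Collection.parse_from_request_task -> Collection. Condensed, with
# the related detail from test_data_types (a field with no converter
# serializes data_type_converter as the string "None", which maps back to
# NoOpTypeConverter):
as_dict = json.loads(collection_to_serialize.json())
assert as_dict == serialized_collection
assert Collection.parse_from_request_task(as_dict) == collection_to_serialize
assert isinstance(get_data_type_converter("None"), NoOpTypeConverter)
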
analytics_event.local_host is None diff --git a/tests/ops/graph/test_graph_differences.py b/tests/ops/graph/test_graph_differences.py deleted file mode 100644 index 6bf6dd6649..0000000000 --- a/tests/ops/graph/test_graph_differences.py +++ /dev/null @@ -1,918 +0,0 @@ -from typing import Any, Dict - -import pytest - -from fides.api.graph.analytics_events import prepare_rerun_graph_analytics_event -from fides.api.graph.config import ( - ROOT_COLLECTION_ADDRESS, - CollectionAddress, - FieldAddress, -) -from fides.api.graph.graph import Edge -from fides.api.graph.graph_differences import ( - GraphDiff, - GraphDiffSummary, - _find_graph_differences, - find_graph_differences_summary, - format_graph_for_caching, -) -from fides.api.graph.traversal import TraversalNode, artificial_traversal_node -from fides.api.models.connectionconfig import ConnectionConfig, ConnectionType -from fides.api.models.policy import Policy -from fides.api.schemas.policy import ActionType -from fides.api.task.graph_task import EMPTY_REQUEST, GraphTask -from fides.api.task.task_resources import TaskResources - -from ..graph.graph_test_util import generate_node - - -def build_test_traversal_env(*traversal_nodes, resources): - """For testing purposes, mock building an env which is modified in place - as part of calling traversal.traverse - - We can build a graph from the "env" variable. - """ - - env: Dict[CollectionAddress, Any] = {} - - for tn in traversal_nodes: - env[tn.address] = GraphTask(traversal_node=tn, resources=resources) - return env - - -def a_traversal_node(): - return TraversalNode( - generate_node("test_db", "a_collection", "id", "A_info", "email") - ) - - -def b_traversal_node(): - return TraversalNode( - generate_node("test_db", "b_collection", "id", "B_info", "upstream_id") - ) - - -def c_traversal_node(): - return TraversalNode( - generate_node("test_db", "c_collection", "upstream_id", "C_info") - ) - - -def d_traversal_node(): - return TraversalNode( - generate_node("test_db", "d_collection", "upstream_id", "D_info") - ) - - -@pytest.fixture(scope="module") -def resources(db): - return TaskResources( - EMPTY_REQUEST, - Policy(), - [ - ConnectionConfig( - key="mock_connection_config_key_test_db", - connection_type=ConnectionType.postgres, - ) - ], - db, - ) - - -@pytest.fixture(scope="module") -def env_a_b(resources): - """Mocks env result that is mutated as of traversal.traverse() - This mimics a simple graph where ROOT->A->B->TERMINATOR - """ - a_tn = a_traversal_node() - b_tn = b_traversal_node() - - root_node = artificial_traversal_node(ROOT_COLLECTION_ADDRESS) - root_node.add_child( - a_tn, - Edge( - FieldAddress( - ROOT_COLLECTION_ADDRESS.dataset, - ROOT_COLLECTION_ADDRESS.collection, - "email", - ), - FieldAddress("test_db", "a_collection", "email"), - ), - ) - a_tn.add_child( - b_tn, - Edge( - FieldAddress("test_db", "a_collection", "id"), - FieldAddress("test_db", "b_collection", "upstream_id"), - ), - ) - b_tn.is_terminal_node = True - - return build_test_traversal_env(a_tn, b_tn, resources=resources) - - -@pytest.fixture(scope="module") -def env_c_a_b(resources): - """Mocks env result that is mutated as of traversal.traverse() - This mimics a simple graph where ROOT->C->A->B->TERMINATOR - """ - c_tn = c_traversal_node() - a_tn = a_traversal_node() - b_tn = b_traversal_node() - - root_node = artificial_traversal_node(ROOT_COLLECTION_ADDRESS) - root_node.add_child( - c_tn, - Edge( - FieldAddress( - ROOT_COLLECTION_ADDRESS.dataset, - ROOT_COLLECTION_ADDRESS.collection, - "email", - ), - 
FieldAddress("test_db", "c_collection", "email"), - ), - ) - c_tn.add_child( - a_tn, - Edge( - FieldAddress("test_db", "c_collection", "id"), - FieldAddress("test_db", "a_collection", "upstream_id"), - ), - ) - a_tn.add_child( - b_tn, - Edge( - FieldAddress("test_db", "a_collection", "id"), - FieldAddress("test_db", "b_collection", "upstream_id"), - ), - ) - - return build_test_traversal_env(c_tn, a_tn, b_tn, resources=resources) - - -@pytest.fixture(scope="module") -def env_d_c_a_b(resources): - """Mocks env result that is mutated as of traversal.traverse() - This mimics a simple graph where ROOT->D->C->A->B->TERMINATOR - """ - d_tn = d_traversal_node() - c_tn = c_traversal_node() - a_tn = a_traversal_node() - b_tn = b_traversal_node() - - root_node = artificial_traversal_node(ROOT_COLLECTION_ADDRESS) - root_node.add_child( - d_tn, - Edge( - FieldAddress( - ROOT_COLLECTION_ADDRESS.dataset, - ROOT_COLLECTION_ADDRESS.collection, - "email", - ), - FieldAddress("test_db", "d_collection", "email"), - ), - ) - d_tn.add_child( - c_tn, - Edge( - FieldAddress("test_db", "d_collection", "id"), - FieldAddress("test_db", "c_collection", "upstream_id"), - ), - ) - c_tn.add_child( - a_tn, - Edge( - FieldAddress("test_db", "c_collection", "id"), - FieldAddress("test_db", "a_collection", "upstream_id"), - ), - ) - a_tn.add_child( - b_tn, - Edge( - FieldAddress("test_db", "a_collection", "id"), - FieldAddress("test_db", "b_collection", "upstream_id"), - ), - ) - - return build_test_traversal_env(d_tn, c_tn, a_tn, b_tn, resources=resources) - - -@pytest.fixture(scope="function") -def env_a_b_c(resources): - """Mocks env result that is mutated as of traversal.traverse() - This mimics a simple graph where ROOT->A->B->C->TERMINATOR - """ - c_tn = c_traversal_node() - a_tn = a_traversal_node() - b_tn = b_traversal_node() - - root_node = artificial_traversal_node(ROOT_COLLECTION_ADDRESS) - root_node.add_child( - a_tn, - Edge( - FieldAddress( - ROOT_COLLECTION_ADDRESS.dataset, - ROOT_COLLECTION_ADDRESS.collection, - "email", - ), - FieldAddress("test_db", "a_collection", "email"), - ), - ) - a_tn.add_child( - b_tn, - Edge( - FieldAddress("test_db", "a_collection", "id"), - FieldAddress("test_db", "b_collection", "upstream_id"), - ), - ) - b_tn.add_child( - c_tn, - Edge( - FieldAddress("test_db", "b_collection", "id"), - FieldAddress("test_db", "c_collection", "upstream_id"), - ), - ) - c_tn.is_terminal_node = True - - return build_test_traversal_env(a_tn, b_tn, c_tn, resources=resources) - - -@pytest.fixture(scope="function") -def env_a_c_b(resources): - """Mocks env result that is mutated as of traversal.traverse() - This mimics a simple graph where ROOT->A->C->B->TERMINATOR - """ - c_tn = c_traversal_node() - a_tn = a_traversal_node() - b_tn = b_traversal_node() - - root_node = artificial_traversal_node(ROOT_COLLECTION_ADDRESS) - root_node.add_child( - a_tn, - Edge( - FieldAddress( - ROOT_COLLECTION_ADDRESS.dataset, - ROOT_COLLECTION_ADDRESS.collection, - "email", - ), - FieldAddress("test_db", "a_collection", "email"), - ), - ) - a_tn.add_child( - c_tn, - Edge( - FieldAddress("test_db", "a_collection", "id"), - FieldAddress("test_db", "c_collection", "upstream_id"), - ), - ) - c_tn.add_child( - b_tn, - Edge( - FieldAddress("test_db", "c_collection", "id"), - FieldAddress("test_db", "b_collection", "upstream_id"), - ), - ) - b_tn.is_terminal_node = True - - return build_test_traversal_env(a_tn, c_tn, b_tn, resources=resources) - - -@pytest.fixture(scope="function") -def env_both_b_c_point_to_d(resources): - 
""" - Root node points to both b and c, and b and c both point to d. - - --> B --> - ROOT D - --> C --> - """ - c_tn = c_traversal_node() - d_tn = d_traversal_node() - b_tn = b_traversal_node() - - root_node = artificial_traversal_node(ROOT_COLLECTION_ADDRESS) - root_node.add_child( - b_tn, - Edge( - FieldAddress( - ROOT_COLLECTION_ADDRESS.dataset, - ROOT_COLLECTION_ADDRESS.collection, - "email", - ), - FieldAddress("test_db", "b_collection", "email"), - ), - ) - root_node.add_child( - c_tn, - Edge( - FieldAddress( - ROOT_COLLECTION_ADDRESS.dataset, - ROOT_COLLECTION_ADDRESS.collection, - "email", - ), - FieldAddress("test_db", "c_collection", "email"), - ), - ) - b_tn.add_child( - d_tn, - Edge( - FieldAddress("test_db", "b_collection", "id"), - FieldAddress("test_db", "d_collection", "upstream_id"), - ), - ) - c_tn.add_child( - d_tn, - Edge( - FieldAddress("test_db", "c_collection", "id"), - FieldAddress("test_db", "d_collection", "upstream_id"), - ), - ) - d_tn.is_terminal_node = True - return build_test_traversal_env(d_tn, c_tn, b_tn, resources=resources) - - -@pytest.fixture(scope="function") -def env_a_to_both_b_c_to_d(resources): - """ - Root node points to a which points to both b and c, - and b and c both point to d. - - --> B --> - ROOT --> A D - --> C --> - """ - a_tn = a_traversal_node() - c_tn = c_traversal_node() - d_tn = d_traversal_node() - b_tn = b_traversal_node() - - root_node = artificial_traversal_node(ROOT_COLLECTION_ADDRESS) - root_node.add_child( - a_tn, - Edge( - FieldAddress( - ROOT_COLLECTION_ADDRESS.dataset, - ROOT_COLLECTION_ADDRESS.collection, - "email", - ), - FieldAddress("test_db", "a_collection", "email"), - ), - ) - a_tn.add_child( - b_tn, - Edge( - FieldAddress("test_db", "a_collection", "id"), - FieldAddress("test_db", "b_collection", "upstream_id"), - ), - ) - a_tn.add_child( - c_tn, - Edge( - FieldAddress("test_db", "a_collection", "id"), - FieldAddress("test_db", "c_collection", "upstream_id"), - ), - ) - b_tn.add_child( - d_tn, - Edge( - FieldAddress("test_db", "b_collection", "id"), - FieldAddress("test_db", "d_collection", "upstream_id"), - ), - ) - c_tn.add_child( - d_tn, - Edge( - FieldAddress("test_db", "c_collection", "id"), - FieldAddress("test_db", "d_collection", "upstream_id"), - ), - ) - d_tn.is_terminal_node = True - return build_test_traversal_env(d_tn, c_tn, b_tn, a_tn, resources=resources) - - -class TestFormatGraphForCaching: - def test_format_graph_for_caching(self, env_a_b_c, env_a_c_b): - """Test two graphs: - - Root -> Graph A -> B -> C -> Terminator - - Root -> Graph A -> C -> B -> Terminator - - """ - - end_nodes = [c_traversal_node().address] - - assert format_graph_for_caching(env_a_b_c, end_nodes) == { - "test_db:a_collection": { - "__ROOT__:__ROOT__": [ - "__ROOT__:__ROOT__:email->test_db:a_collection:email" - ] - }, - "test_db:b_collection": { - "test_db:a_collection": [ - "test_db:a_collection:id->test_db:b_collection:upstream_id" - ] - }, - "test_db:c_collection": { - "test_db:b_collection": [ - "test_db:b_collection:id->test_db:c_collection:upstream_id" - ] - }, - "__ROOT__:__ROOT__": {}, - "__TERMINATE__:__TERMINATE__": {"test_db:c_collection": []}, - } - - # Now swap positions of b and c collection - end_nodes = [b_traversal_node().address] - assert format_graph_for_caching(env_a_c_b, end_nodes) == { - "test_db:a_collection": { - "__ROOT__:__ROOT__": [ - "__ROOT__:__ROOT__:email->test_db:a_collection:email" - ] - }, - "test_db:c_collection": { - "test_db:a_collection": [ - 
"test_db:a_collection:id->test_db:c_collection:upstream_id" - ] - }, - "test_db:b_collection": { - "test_db:c_collection": [ - "test_db:c_collection:id->test_db:b_collection:upstream_id" - ] - }, - "__ROOT__:__ROOT__": {}, - "__TERMINATE__:__TERMINATE__": {"test_db:b_collection": []}, - } - - -class TestGraphDiff: - def test_find_graph_differences_no_previous(self, env_a_b_c): - """Test no previous graph to compare""" - previous_graph = {} - formatted_current_graph = format_graph_for_caching( - env_a_b_c, [c_traversal_node().address] - ) - assert not _find_graph_differences( - previous_graph=previous_graph, - current_graph=formatted_current_graph, - previous_results={}, - previous_erasure_results={}, - ) - - assert not find_graph_differences_summary( - previous_graph=previous_graph, - current_graph=formatted_current_graph, - previous_results={}, - previous_erasure_results={}, - ) - - def test_find_graph_differences_no_change(self, env_a_b_c): - formatted_graph = format_graph_for_caching( - env_a_b_c, [c_traversal_node().address] - ) - graph_diff = _find_graph_differences( - previous_graph=formatted_graph, - current_graph=formatted_graph, - previous_results={}, - previous_erasure_results={}, - ) - assert graph_diff == GraphDiff( - previous_collections=[ - "test_db:a_collection", - "test_db:b_collection", - "test_db:c_collection", - ], - current_collections=[ - "test_db:a_collection", - "test_db:b_collection", - "test_db:c_collection", - ], - added_collections=[], - removed_collections=[], - added_edges=[], - removed_edges=[], - already_processed_access_collections=[], - already_processed_erasure_collections=[], - skipped_added_edges=[], - ) - assert find_graph_differences_summary( - formatted_graph, formatted_graph, {}, {} - ) == GraphDiffSummary( - prev_collection_count=3, - curr_collection_count=3, - added_collection_count=0, - removed_collection_count=0, - added_edge_count=0, - removed_edge_count=0, - already_processed_access_collection_count=0, - already_processed_erasure_collection_count=0, - skipped_added_edge_count=0, - ) - - def test_find_graph_differences_collection_added(self, env_a_b, env_a_b_c): - previous_graph = format_graph_for_caching( - env_a_b, end_nodes=[b_traversal_node().address] - ) - current_graph = format_graph_for_caching( - env_a_b_c, end_nodes=[c_traversal_node().address] - ) - graph_diff = _find_graph_differences(previous_graph, current_graph, {}, {}) - assert graph_diff == GraphDiff( - previous_collections=["test_db:a_collection", "test_db:b_collection"], - current_collections=[ - "test_db:a_collection", - "test_db:b_collection", - "test_db:c_collection", - ], - added_collections=["test_db:c_collection"], - removed_collections=[], - added_edges=["test_db:b_collection:id->test_db:c_collection:upstream_id"], - removed_edges=[], - already_processed_access_collections=[], - already_processed_erasure_collections=[], - skipped_added_edges=[], - ) - - assert find_graph_differences_summary( - previous_graph, current_graph, {}, {} - ) == GraphDiffSummary( - prev_collection_count=2, - curr_collection_count=3, - added_collection_count=1, - removed_collection_count=0, - added_edge_count=1, - removed_edge_count=0, - already_processed_access_collection_count=0, - already_processed_erasure_collection_count=0, - skipped_added_edge_count=0, - ) - - def test_find_graph_differences_collection_removed(self, env_a_b_c, env_a_b): - previous_graph = format_graph_for_caching( - env_a_b_c, end_nodes=[c_traversal_node().address] - ) - current_graph = format_graph_for_caching( - env_a_b, 
end_nodes=[b_traversal_node().address] - ) - graph_diff = _find_graph_differences(previous_graph, current_graph, {}, {}) - assert graph_diff == GraphDiff( - previous_collections=[ - "test_db:a_collection", - "test_db:b_collection", - "test_db:c_collection", - ], - current_collections=["test_db:a_collection", "test_db:b_collection"], - added_collections=[], - removed_collections=["test_db:c_collection"], - added_edges=[], - removed_edges=["test_db:b_collection:id->test_db:c_collection:upstream_id"], - already_processed_access_collections=[], - already_processed_erasure_collections=[], - skipped_added_edges=[], - ) - assert find_graph_differences_summary( - previous_graph, current_graph, {}, {} - ) == GraphDiffSummary( - prev_collection_count=3, - curr_collection_count=2, - added_collection_count=0, - removed_collection_count=1, - added_edge_count=0, - removed_edge_count=1, - already_processed_access_collection_count=0, - already_processed_erasure_collection_count=0, - skipped_added_edge_count=0, - ) - - def test_find_graph_differences_collection_order_changed( - self, env_a_b_c, env_a_c_b - ): - previous_graph = format_graph_for_caching( - env_a_b_c, end_nodes=[c_traversal_node().address] - ) - current_graph = format_graph_for_caching( - env_a_c_b, end_nodes=[b_traversal_node().address] - ) - previous_results = {"test_db:a_collection": []} - graph_diff = _find_graph_differences( - previous_graph, current_graph, previous_results, {} - ) - - assert graph_diff == GraphDiff( - previous_collections=[ - "test_db:a_collection", - "test_db:b_collection", - "test_db:c_collection", - ], - current_collections=[ - "test_db:a_collection", - "test_db:b_collection", - "test_db:c_collection", - ], - added_collections=[], - removed_collections=[], - added_edges=[ - "test_db:a_collection:id->test_db:c_collection:upstream_id", - "test_db:c_collection:id->test_db:b_collection:upstream_id", - ], - removed_edges=[ - "test_db:a_collection:id->test_db:b_collection:upstream_id", - "test_db:b_collection:id->test_db:c_collection:upstream_id", - ], - already_processed_access_collections=["test_db:a_collection"], - already_processed_erasure_collections=[], - skipped_added_edges=[], - ) - assert find_graph_differences_summary( - previous_graph, current_graph, previous_results, {} - ) == GraphDiffSummary( - prev_collection_count=3, - curr_collection_count=3, - added_collection_count=0, - removed_collection_count=0, - added_edge_count=2, - removed_edge_count=2, - already_processed_access_collection_count=1, - already_processed_erasure_collection_count=0, - skipped_added_edge_count=0, - ) - - def test_find_graph_differences_collection_added_upstream(self, env_a_b, env_c_a_b): - previous_graph = format_graph_for_caching( - env_a_b, end_nodes=[b_traversal_node().address] - ) - current_graph = format_graph_for_caching( - env_c_a_b, end_nodes=[b_traversal_node().address] - ) - - previous_results = {"test_db:a_collection": []} - - graph_diff = _find_graph_differences( - previous_graph, current_graph, previous_results, {} - ) - - assert graph_diff == GraphDiff( - previous_collections=["test_db:a_collection", "test_db:b_collection"], - current_collections=[ - "test_db:a_collection", - "test_db:b_collection", - "test_db:c_collection", - ], - added_collections=["test_db:c_collection"], - removed_collections=[], - added_edges=[ - "__ROOT__:__ROOT__:email->test_db:c_collection:email", - "test_db:c_collection:id->test_db:a_collection:upstream_id", - ], - removed_edges=["__ROOT__:__ROOT__:email->test_db:a_collection:email"], - 
already_processed_access_collections=["test_db:a_collection"], - already_processed_erasure_collections=[], - skipped_added_edges=[ - "test_db:c_collection:id->test_db:a_collection:upstream_id" - ], - ) - - assert find_graph_differences_summary( - previous_graph, current_graph, previous_results, {} - ) == GraphDiffSummary( - prev_collection_count=2, - curr_collection_count=3, - added_collection_count=1, - removed_collection_count=0, - added_edge_count=2, - removed_edge_count=1, - already_processed_access_collection_count=1, - already_processed_erasure_collection_count=0, - skipped_added_edge_count=1, - ) - - def test_find_graph_differences_collection_added_far_upstream( - self, env_d_c_a_b, env_c_a_b - ): - previous_graph = format_graph_for_caching( - env_c_a_b, end_nodes=[b_traversal_node().address] - ) - current_graph = format_graph_for_caching( - env_d_c_a_b, end_nodes=[b_traversal_node().address] - ) - - previous_results = {"test_db:a_collection": []} - - graph_diff = _find_graph_differences( - previous_graph, current_graph, previous_results, {} - ) - assert graph_diff == GraphDiff( - previous_collections=[ - "test_db:a_collection", - "test_db:b_collection", - "test_db:c_collection", - ], - current_collections=[ - "test_db:a_collection", - "test_db:b_collection", - "test_db:c_collection", - "test_db:d_collection", - ], - added_collections=["test_db:d_collection"], - removed_collections=[], - added_edges=[ - "__ROOT__:__ROOT__:email->test_db:d_collection:email", - "test_db:d_collection:id->test_db:c_collection:upstream_id", - ], - removed_edges=["__ROOT__:__ROOT__:email->test_db:c_collection:email"], - already_processed_access_collections=["test_db:a_collection"], - skipped_added_edges=[], - ) - - def test_find_graph_differences_collection_added_upstream_multiple( - self, env_both_b_c_point_to_d, env_a_to_both_b_c_to_d - ): - previous_graph = format_graph_for_caching( - env_both_b_c_point_to_d, end_nodes=[b_traversal_node().address] - ) - current_graph = format_graph_for_caching( - env_a_to_both_b_c_to_d, end_nodes=[b_traversal_node().address] - ) - - previous_results = {"test_db:b_collection": []} - - graph_diff = _find_graph_differences( - previous_graph, current_graph, previous_results, {} - ) - - assert graph_diff == GraphDiff( - previous_collections=[ - "test_db:b_collection", - "test_db:c_collection", - "test_db:d_collection", - ], - current_collections=[ - "test_db:a_collection", - "test_db:b_collection", - "test_db:c_collection", - "test_db:d_collection", - ], - added_collections=["test_db:a_collection"], - removed_collections=[], - added_edges=[ - "__ROOT__:__ROOT__:email->test_db:a_collection:email", - "test_db:a_collection:id->test_db:b_collection:upstream_id", - "test_db:a_collection:id->test_db:c_collection:upstream_id", - ], - removed_edges=[ - "__ROOT__:__ROOT__:email->test_db:b_collection:email", - "__ROOT__:__ROOT__:email->test_db:c_collection:email", - ], - already_processed_access_collections=["test_db:b_collection"], - already_processed_erasure_collections=[], - skipped_added_edges=[ - "test_db:a_collection:id->test_db:b_collection:upstream_id" - ], - ) - - def test_find_graph_differences_collection_added_upstream_erasure( - self, env_a_b, env_c_a_b - ): - previous_graph = format_graph_for_caching( - env_a_b, end_nodes=[b_traversal_node().address] - ) - current_graph = format_graph_for_caching( - env_c_a_b, end_nodes=[b_traversal_node().address] - ) - - # The original access graph has already run to get data from just a and b - previous_results = 
{"test_db:a_collection": [], "test_db:b_collection": []} - # The erasure has already processed one collection "a" - erasure_results = {"test_db:a_collection": 1} - - graph_diff = _find_graph_differences( - previous_graph, current_graph, previous_results, erasure_results - ) - assert graph_diff == GraphDiff( - previous_collections=["test_db:a_collection", "test_db:b_collection"], - current_collections=[ - "test_db:a_collection", - "test_db:b_collection", - "test_db:c_collection", - ], - added_collections=["test_db:c_collection"], - removed_collections=[], - added_edges=[ - "__ROOT__:__ROOT__:email->test_db:c_collection:email", - "test_db:c_collection:id->test_db:a_collection:upstream_id", - ], - removed_edges=["__ROOT__:__ROOT__:email->test_db:a_collection:email"], - already_processed_access_collections=[ - "test_db:a_collection", - "test_db:b_collection", - ], - skipped_added_edges=[ - "test_db:c_collection:id->test_db:a_collection:upstream_id" - ], - already_processed_erasure_collections=["test_db:a_collection"], - ) - - assert find_graph_differences_summary( - previous_graph, current_graph, previous_results, erasure_results - ) == GraphDiffSummary( - prev_collection_count=2, - curr_collection_count=3, - added_collection_count=1, - removed_collection_count=0, - added_edge_count=2, - removed_edge_count=1, - already_processed_access_collection_count=2, - skipped_added_edge_count=1, - already_processed_erasure_collection_count=1, - ) - - -class TestCachePrivacyRequestAccessGraph: - def test_cache_privacy_request_access_graph(self, privacy_request, env_a_b_c): - end_nodes = [c_traversal_node().address] - formatted_graph = format_graph_for_caching(env_a_b_c, end_nodes) - privacy_request.cache_access_graph(formatted_graph) - - cached_data = privacy_request.get_cached_access_graph() - assert cached_data == formatted_graph - - def test_no_access_graph_cached(self, privacy_request): - cached_data = privacy_request.get_cached_access_graph() - assert cached_data is None - - -class TestPrepareRerunAccessGraphEvent: - def test_rerun_access_graph_event_no_previous_graph( - self, privacy_request, env_a_b_c, resources - ): - end_nodes = [c_traversal_node().address] - analytics_event = prepare_rerun_graph_analytics_event( - privacy_request, env_a_b_c, end_nodes, resources, ActionType.access - ) - assert analytics_event is None - - def test_rerun_access_graph_analytics_event( - self, privacy_request, env_a_b, env_a_b_c, resources - ): - end_nodes = [b_traversal_node().address] - formatted_graph = format_graph_for_caching(env_a_b, end_nodes) - privacy_request.cache_access_graph(formatted_graph) - - end_nodes = [c_traversal_node().address] - analytics_event = prepare_rerun_graph_analytics_event( - privacy_request, env_a_b_c, end_nodes, resources, step=ActionType.access - ) - - assert analytics_event.docker is True - assert analytics_event.event == "rerun_access_graph" - assert analytics_event.event_created_at is not None - assert analytics_event.extra_data == { - "prev_collection_count": 2, - "curr_collection_count": 3, - "added_collection_count": 1, - "removed_collection_count": 0, - "added_edge_count": 1, - "removed_edge_count": 0, - "already_processed_access_collection_count": 0, - "already_processed_erasure_collection_count": 0, - "skipped_added_edge_count": 0, - "privacy_request": privacy_request.id, - } - - assert analytics_event.error is None - assert analytics_event.status_code is None - assert analytics_event.endpoint is None - assert analytics_event.local_host is None - - def 
test_rerun_erasure_graph_analytics_event( - self, privacy_request, env_a_b, env_a_b_c, resources - ): - end_nodes = [b_traversal_node().address] - formatted_graph = format_graph_for_caching(env_a_b, end_nodes) - privacy_request.cache_access_graph(formatted_graph) - - end_nodes = [c_traversal_node().address] - analytics_event = prepare_rerun_graph_analytics_event( - privacy_request, env_a_b_c, end_nodes, resources, step=ActionType.erasure - ) - - assert analytics_event.docker is True - assert analytics_event.event == "rerun_erasure_graph" - assert analytics_event.event_created_at is not None - assert analytics_event.extra_data == { - "prev_collection_count": 2, - "curr_collection_count": 3, - "added_collection_count": 1, - "removed_collection_count": 0, - "added_edge_count": 1, - "removed_edge_count": 0, - "already_processed_access_collection_count": 0, - "already_processed_erasure_collection_count": 0, - "skipped_added_edge_count": 0, - "privacy_request": privacy_request.id, - } - - assert analytics_event.error is None - assert analytics_event.status_code is None - assert analytics_event.endpoint is None - assert analytics_event.local_host is None diff --git a/tests/ops/integration_tests/limiter/test_rate_limiter.py b/tests/ops/integration_tests/limiter/test_rate_limiter.py index e386ba22b4..c929d65d8e 100644 --- a/tests/ops/integration_tests/limiter/test_rate_limiter.py +++ b/tests/ops/integration_tests/limiter/test_rate_limiter.py @@ -16,7 +16,6 @@ ConnectionType, ) from fides.api.models.datasetconfig import DatasetConfig -from fides.api.models.privacy_request import PrivacyRequest from fides.api.models.sql_models import Dataset as CtlDataset from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors.limiter.rate_limiter import ( @@ -25,11 +24,12 @@ RateLimiterRequest, RateLimiterTimeoutException, ) -from fides.api.task import graph_task +from fides.api.task.graph_runners import access_runner from fides.api.util.saas_util import ( load_config_with_replacement, load_dataset_with_replacement, ) +from tests.conftest import access_runner_tester @pytest.fixture @@ -220,22 +220,29 @@ def test_limiter_times_out_when_bucket_full() -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_rate_limiter_full_integration( db, + dsr_version, + request, policy, + privacy_request, stripe_connection_config, stripe_dataset_config, stripe_identity_email, ) -> None: """Test rate limiter by creating privacy request to Stripe and setting a rate limit""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + rate_limit = 1 rate_limit_config = {"limits": [{"rate": rate_limit, "period": "second"}]} stripe_connection_config.saas_config["rate_limit_config"] = rate_limit_config # set up privacy request to Stripe - privacy_request = PrivacyRequest( - id=f"test_stripe_access_request_task_{random.randint(0, 1000)}" - ) + identity = Identity(**{"email": stripe_identity_email}) privacy_request.cache_identity(identity) merged_graph = stripe_dataset_config.get_graph() @@ -244,7 +251,7 @@ async def test_rate_limiter_full_integration( # create call log spy and execute request spy = call_log_spy(Session.send) with mock.patch.object(Session, "send", spy): - await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, diff --git a/tests/ops/integration_tests/saas/connector_runner.py b/tests/ops/integration_tests/saas/connector_runner.py index 
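
# The parametrization added to the rate-limiter test above is the idiom used
# throughout the rest of this diff: `dsr_version` names a fixture, and
# request.getfixturevalue() activates it so the same test body runs once
# under the DSR 3.0 scheduler and once under 2.0. As a minimal sketch (the
# two fixtures are assumed to toggle CONFIG.execution.use_dsr_3_0):
@pytest.mark.parametrize("dsr_version", ["use_dsr_3_0", "use_dsr_2_0"])
def test_some_integration(dsr_version, request):
    request.getfixturevalue(dsr_version)  # REQUIRED to test both versions
    # ...the same assertions then execute against whichever scheduler is active
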
1f6c03af92..c80cf901c3 100644 --- a/tests/ops/integration_tests/saas/connector_runner.py +++ b/tests/ops/integration_tests/saas/connector_runner.py @@ -1,11 +1,13 @@ +import json import random from typing import Any, Dict, List, Optional, Tuple from sqlalchemy.orm import Session from fides.api.cryptography import cryptographic_util -from fides.api.graph.config import GraphDataset +from fides.api.graph.config import CollectionAddress, GraphDataset from fides.api.graph.graph import DatasetGraph +from fides.api.graph.traversal import Traversal, TraversalNode from fides.api.models.connectionconfig import ( AccessLevel, ConnectionConfig, @@ -20,9 +22,11 @@ ) from fides.api.models.privacy_preference_v2 import PrivacyPreferenceHistory from fides.api.models.privacy_request import ( + ExecutionLogStatus, PrivacyRequest, PrivacyRequestStatus, ProvidedIdentity, + RequestTask, ) from fides.api.models.sql_models import Dataset as CtlDataset from fides.api.schemas.policy import ActionType @@ -32,9 +36,13 @@ from fides.api.service.privacy_request.request_runner_service import ( build_consent_dataset_graph, ) -from fides.api.task import graph_task +from fides.api.task.create_request_tasks import ( + collect_tasks_fn, + persist_initial_erasure_request_tasks, + persist_new_access_request_tasks, +) from fides.api.task.graph_task import get_cached_data_for_erasures -from fides.api.util.cache import FidesopsRedis +from fides.api.util.cache import CustomJSONEncoder, FidesopsRedis from fides.api.util.collection_util import Row from fides.api.util.saas_util import ( load_config_with_replacement, @@ -94,29 +102,39 @@ async def access_request( privacy_request_id: Optional[str] = None, ) -> Dict[str, List[Row]]: """Access request for a given access policy and identities""" + + from tests.conftest import access_runner_tester + fides_key = self.connection_config.key - privacy_request = PrivacyRequest( - id=( - privacy_request_id - or f"test_{fides_key}_access_request_{random.randint(0, 1000)}" - ) + privacy_request = create_privacy_request_with_policy_rules( + access_policy, None, privacy_request_id ) + identity = Identity(**identities) privacy_request.cache_identity(identity) + graph_list = [self.dataset_config.get_graph()] + connection_config_list = [self.connection_config] + _process_external_references(self.db, graph_list, connection_config_list) + dataset_graph = DatasetGraph(*graph_list) + # cache external dataset data if self.external_references: self.cache.set_encoded_object( f"{privacy_request.id}__access_request__{self.connector_type}_external_dataset:{self.connector_type}_external_collection", [self.external_references], ) - - graph_list = [self.dataset_config.get_graph()] - connection_config_list = [self.connection_config] - _process_external_references(self.db, graph_list, connection_config_list) - dataset_graph = DatasetGraph(*graph_list) - - access_results = await graph_task.run_access_request( + if CONFIG.execution.use_dsr_3_0: + mock_external_results_3_0( + privacy_request, + dataset_graph, + identities, + self.connector_type, + self.external_references, + is_erasure=False, + ) + + access_results = access_runner_tester( privacy_request, access_policy, dataset_graph, @@ -186,10 +204,12 @@ async def old_consent_request( """ Consent requests using consent preferences on the privacy request (old workflow) """ - privacy_request = PrivacyRequest( - id=f"test_{self.connection_config.key}_old_consent_request_{random.randint(0, 1000)}", - status=PrivacyRequestStatus.pending, + from tests.conftest import 
consent_runner_tester + + privacy_request = create_privacy_request_with_policy_rules( + consent_policy, None, None ) + identity = Identity(**identities) privacy_request.cache_identity(identity) @@ -197,7 +217,7 @@ async def old_consent_request( {"data_use": "marketing.advertising", "opt_in": True} ] privacy_request.save(self.db) - opt_in = await graph_task.run_consent_request( + opt_in = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([self.dataset_config]), @@ -210,7 +230,7 @@ async def old_consent_request( {"data_use": "marketing.advertising", "opt_in": False} ] privacy_request.save(self.db) - opt_out = await graph_task.run_consent_request( + opt_out = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([self.dataset_config]), @@ -230,20 +250,19 @@ async def new_consent_request( """ Consent requests using privacy preference history (new workflow) """ - privacy_request = PrivacyRequest( - id=( - privacy_request_id - or f"test_{self.connection_config.key}_new_consent_request_{random.randint(0, 1000)}" - ), - status=PrivacyRequestStatus.pending, + from tests.conftest import consent_runner_tester + + privacy_request = create_privacy_request_with_policy_rules( + consent_policy, None, privacy_request_id ) + privacy_request.save(self.db) identity = Identity(**identities) privacy_request.cache_identity(identity) _privacy_preference_history(self.db, privacy_request, identities, opt_in=True) - opt_in = await graph_task.run_consent_request( + opt_in = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([self.dataset_config]), @@ -253,7 +272,7 @@ async def new_consent_request( ) _privacy_preference_history(self.db, privacy_request, identities, opt_in=False) - opt_out = await graph_task.run_consent_request( + opt_out = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([self.dataset_config]), @@ -261,7 +280,6 @@ async def new_consent_request( identities, self.db, ) - return {"opt_in": opt_in.popitem()[1], "opt_out": opt_out.popitem()[1]} async def _base_erasure_request( @@ -271,29 +289,42 @@ async def _base_erasure_request( identities: Dict[str, Any], privacy_request_id: Optional[str] = None, ) -> Tuple[Dict, Dict]: + from tests.conftest import access_runner_tester, erasure_runner_tester + fides_key = self.connection_config.key - privacy_request = PrivacyRequest( - id=( - privacy_request_id - or f"test_{fides_key}_access_request_{random.randint(0, 1000)}" - ) + + privacy_request = create_privacy_request_with_policy_rules( + access_policy, erasure_policy, privacy_request_id ) identity = Identity(**identities) privacy_request.cache_identity(identity) + graph_list = [self.dataset_config.get_graph()] + connection_config_list = [self.connection_config] + _process_external_references(self.db, graph_list, connection_config_list) + dataset_graph = DatasetGraph(*graph_list) + # cache external dataset data if self.erasure_external_references: + # DSR 2.0 self.cache.set_encoded_object( f"{privacy_request.id}__access_request__{self.connector_type}_external_dataset:{self.connector_type}_external_collection", [self.erasure_external_references], ) - - graph_list = [self.dataset_config.get_graph()] - connection_config_list = [self.connection_config] - _process_external_references(self.db, graph_list, connection_config_list) - dataset_graph = DatasetGraph(*graph_list) - - access_results = await graph_task.run_access_request( + # DSR 3.0 + if CONFIG.execution.use_dsr_3_0: + # DSR 3.0 
does not pull its results out of the cache, but rather + # off of the Request Tasks - + mock_external_results_3_0( + privacy_request, + dataset_graph, + identities, + self.connector_type, + self.erasure_external_references, + is_erasure=True, + ) + + access_results = access_runner_tester( privacy_request, access_policy, dataset_graph, @@ -312,7 +343,7 @@ async def _base_erasure_request( access_results[f"{fides_key}:{collection['name']}"] ), f"No rows returned for collection '{collection['name']}'" - erasure_results = await graph_task.run_erasure( + erasure_results = erasure_runner_tester( privacy_request, erasure_policy, dataset_graph, @@ -325,6 +356,42 @@ async def _base_erasure_request( return access_results, erasure_results +def create_privacy_request_with_policy_rules( + access_or_consent_policy: Policy, # In the event only one policy is passed in + erasure_policy: Optional[ + Policy + ] = None, # If two policies are passed in, second goes here. + privacy_request_id: Optional[str] = None, +) -> PrivacyRequest: + """ + Create a proper Privacy Request with a single Policy by combining policy rules passed in for this particular + test and persist to the database + + Privacy Requests have only one policy, but tests using the connector runner can pass in policies separately. + DSR 3.0 scheduler requires that Privacy Requests are formulated properly and persisted to the database. + """ + session = Session.object_session(access_or_consent_policy) + + if erasure_policy: + for rule in erasure_policy.rules: + # Move the erasure rules over to the access policy if applicable, so one Policy holds + # all of the rules + rule.policy_id = access_or_consent_policy.id + rule.save(session) + + privacy_request = PrivacyRequest.create( + db=session, + data={ + "policy_id": access_or_consent_policy.id, + "status": PrivacyRequestStatus.in_processing, + }, + ) + if privacy_request_id: + privacy_request.id = privacy_request_id + privacy_request.save(session) + return privacy_request + + def _config(connector_type: str) -> Dict[str, Any]: return load_config_with_replacement( f"data/saas/config/{connector_type}_config.yml", @@ -511,3 +578,49 @@ def generate_random_phone_number() -> str: Generate a random phone number in the format of E.164, +1112223333 """ return f"+{random.randrange(100,999)}555{random.randrange(1000,9999)}" + + +def mock_external_results_3_0( + privacy_request: PrivacyRequest, + dataset_graph: DatasetGraph, + identities: Dict[str, Any], + connector_type: ConnectionType, + external_references: Dict[str, Any], + is_erasure: bool, +): + """ + Mock external results for DSR 3.0 by going ahead and building the Request Tasks up front and caching the + external results on the appropriate external Request Task + """ + session = Session.object_session(privacy_request) + traversal: Traversal = Traversal(dataset_graph, identities) + traversal_nodes: Dict[CollectionAddress, TraversalNode] = {} + end_nodes: List[CollectionAddress] = traversal.traverse( + traversal_nodes, collect_tasks_fn + ) + persist_new_access_request_tasks( + Session.object_session(privacy_request), + privacy_request, + traversal, + traversal_nodes, + end_nodes, + dataset_graph, + ) + external_request_task = privacy_request.access_tasks.filter( + RequestTask.collection_address + == f"{connector_type}_external_dataset:{connector_type}_external_collection" + ).first() + external_request_task.access_data = json.dumps( + [external_references], cls=CustomJSONEncoder + ) + external_request_task.data_for_erasures = json.dumps( + 
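
# create_privacy_request_with_policy_rules (defined above) exists because the
# DSR 3.0 scheduler needs a persisted PrivacyRequest pointing at a single
# Policy, while connector-runner tests historically passed access and erasure
# policies separately. A usage sketch with the runner's existing policy
# fixtures:
privacy_request = create_privacy_request_with_policy_rules(
    access_policy,            # all rules end up attached to this policy
    erasure_policy,           # its rules are re-pointed via rule.policy_id
    privacy_request_id=None,  # omit to keep the generated id
)
assert privacy_request.status == PrivacyRequestStatus.in_processing
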
[external_references], cls=CustomJSONEncoder + ) + external_request_task.save(session) + external_request_task.update_status(session, ExecutionLogStatus.complete) + erasure_end_nodes: List[CollectionAddress] = list(dataset_graph.nodes.keys()) + # Further, erasure tasks are typically built when access tasks are created so the graphs match + if is_erasure: + persist_initial_erasure_request_tasks( + session, privacy_request, traversal_nodes, erasure_end_nodes, dataset_graph + ) diff --git a/tests/ops/integration_tests/saas/request_override/test_firebase_auth_task.py b/tests/ops/integration_tests/saas/request_override/test_firebase_auth_task.py index 38ff50d360..e6c582700b 100644 --- a/tests/ops/integration_tests/saas/request_override/test_firebase_auth_task.py +++ b/tests/ops/integration_tests/saas/request_override/test_firebase_auth_task.py @@ -1,34 +1,39 @@ -from uuid import uuid4 - import pytest from firebase_admin import auth from firebase_admin.auth import UserNotFoundError, UserRecord from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.saas_request.override_implementations.firebase_auth_request_overrides import ( firebase_auth_user_delete, initialize_firebase, ) -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match +from tests.ops.test_helpers.cache_secrets_helper import clear_cache_identities @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_firebase_auth_access_request( db, + privacy_request, policy, + dsr_version, + request, firebase_auth_connection_config, firebase_auth_dataset_config, firebase_auth_user: auth.ImportUserRecord, ) -> None: """Full access request based on the Firebase Auth SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest(id=f"test_firebase_access_request_task_{uuid4()}") identity = Identity(**{"email": firebase_auth_user.email}) privacy_request.cache_identity(identity) @@ -36,7 +41,7 @@ async def test_firebase_auth_access_request( merged_graph = firebase_auth_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -96,30 +101,54 @@ async def test_firebase_auth_access_request( @pytest.mark.asyncio @pytest.mark.usefixtures("firebase_auth_user") @pytest.mark.parametrize( - "identity_info, message", + "identity_info, message, dsr_version", [ - ({"email": "a_fake_email@ethyca.com"}, "Could not find user with email"), - ({"phone_number": "+10000000000"}, "Could not find user with phone_number"), + ( + {"email": "a_fake_email@ethyca.com"}, + "Could not find user with email", + "use_dsr_3_0", + ), + ( + {"phone_number": "+10000000000"}, + "Could not find user with phone_number", + "use_dsr_3_0", + ), + ( + {"email": "a_fake_email@ethyca.com"}, + "Could not find user with email", + "use_dsr_2_0", + ), + ( + {"phone_number": "+10000000000"}, + "Could not find user with phone_number", + "use_dsr_2_0", + ), ], ) async def test_firebase_auth_access_request_non_existent_users( identity_info, message, + dsr_version, db, + request, + privacy_request, policy, 
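
# Since DSR 3.0 reads collection results off RequestTask rows rather than the
# Redis cache, mock_external_results_3_0 (above) pre-builds the access task
# graph and marks the external task complete with the mocked rows. A sketch
# of inspecting the seeded task, using a hypothetical "stripe" connector_type
# and assuming the status and JSON values round-trip as written (the filter
# mirrors the one in the helper):
task = privacy_request.access_tasks.filter(
    RequestTask.collection_address
    == "stripe_external_dataset:stripe_external_collection"
).first()
assert task.status == ExecutionLogStatus.complete  # set by update_status above
assert json.loads(task.access_data) == [external_references]
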
firebase_auth_connection_config, firebase_auth_dataset_config, loguru_caplog, ) -> None: """Ensure that firebase access request task gracefully handles non-existent users""" - privacy_request = PrivacyRequest(id=f"test_firebase_access_request_task_{uuid4()}") + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + clear_cache_identities(privacy_request.id) + identity = Identity(**identity_info) privacy_request.cache_identity(identity) + dataset_name = firebase_auth_connection_config.get_saas_config().fides_key merged_graph = firebase_auth_dataset_config.get_graph() graph = DatasetGraph(merged_graph) # just ensure we don't error out here - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -146,16 +175,24 @@ async def test_firebase_auth_access_request_non_existent_users( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_firebase_auth_access_request_phone_number_identity( db, policy, + dsr_version, + request, + privacy_request, firebase_auth_connection_config, firebase_auth_dataset_config, firebase_auth_user: auth.ImportUserRecord, ) -> None: """Full access request based on the Firebase Auth SaaS config using a phone number identity""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + clear_cache_identities(privacy_request.id) - privacy_request = PrivacyRequest(id=f"test_firebase_access_request_task_{uuid4()}") identity = Identity(**{"phone_number": firebase_auth_user.phone_number}) privacy_request.cache_identity(identity) @@ -163,7 +200,7 @@ async def test_firebase_auth_access_request_phone_number_identity( merged_graph = firebase_auth_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -224,9 +261,15 @@ async def test_firebase_auth_access_request_phone_number_identity( ) @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_firebase_auth_update_request( db, - policy, + dsr_version, + request, + privacy_request, firebase_auth_connection_config, firebase_auth_dataset_config, firebase_auth_user: auth.ImportUserRecord, @@ -234,8 +277,11 @@ async def test_firebase_auth_update_request( firebase_auth_secrets, ) -> None: """Update request based on the Firebase Auth SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) - privacy_request = PrivacyRequest(id=f"test_firebase_update_request_task_{uuid4()}") identity = Identity(**{"email": firebase_auth_user.email}) privacy_request.cache_identity(identity) @@ -243,9 +289,9 @@ async def test_firebase_auth_update_request( merged_graph = firebase_auth_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [firebase_auth_connection_config], {"email": firebase_auth_user.email}, @@ -264,7 +310,7 @@ async def test_firebase_auth_update_request( ], ) - await graph_task.run_erasure( + erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, @@ -313,9 +359,15 @@ async def test_firebase_auth_update_request( ) @pytest.mark.integration_saas @pytest.mark.asyncio 
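
# Several Firebase tests in this file now share one privacy_request fixture
# across parametrized runs, so each case first clears any identity cached by
# a previous case before caching its own. The recurring pattern, condensed
# from the hunks above:
clear_cache_identities(privacy_request.id)
identity = Identity(phone_number=firebase_auth_user.phone_number)
privacy_request.cache_identity(identity)
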
+@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_firebase_auth_update_request_phone_number_identity( db, - policy, + dsr_version, + request, + privacy_request, firebase_auth_connection_config, firebase_auth_dataset_config, firebase_auth_user: auth.ImportUserRecord, @@ -323,8 +375,11 @@ async def test_firebase_auth_update_request_phone_number_identity( firebase_auth_secrets, ) -> None: """Update request based on the Firebase Auth SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) - privacy_request = PrivacyRequest(id=f"test_firebase_update_request_task_{uuid4()}") identity = Identity(**{"phone_number": firebase_auth_user.phone_number}) privacy_request.cache_identity(identity) @@ -332,9 +387,9 @@ async def test_firebase_auth_update_request_phone_number_identity( merged_graph = firebase_auth_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [firebase_auth_connection_config], {"phone_number": firebase_auth_user.phone_number}, @@ -353,7 +408,7 @@ async def test_firebase_auth_update_request_phone_number_identity( ], ) - await graph_task.run_erasure( + erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, @@ -399,9 +454,15 @@ async def test_firebase_auth_update_request_phone_number_identity( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_firebase_auth_delete_request( db, - policy, + dsr_version, + request, + privacy_request, firebase_auth_connection_config, firebase_auth_dataset_config, firebase_auth_user: auth.ImportUserRecord, @@ -409,8 +470,11 @@ async def test_firebase_auth_delete_request( firebase_auth_secrets, ) -> None: """Delete request based on the Firebase Auth SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) - privacy_request = PrivacyRequest(id=f"test_firebase_delete_request_task_{uuid4()}") identity = Identity(**{"email": firebase_auth_user.email}) privacy_request.cache_identity(identity) @@ -418,9 +482,9 @@ async def test_firebase_auth_delete_request( merged_graph = firebase_auth_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [firebase_auth_connection_config], {"email": firebase_auth_user.email}, @@ -442,7 +506,7 @@ async def test_firebase_auth_delete_request( masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, @@ -470,9 +534,15 @@ async def test_firebase_auth_delete_request( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_firebase_auth_delete_request_phone_number_identity( db, - policy, + dsr_version, + request, + privacy_request, firebase_auth_connection_config, firebase_auth_dataset_config, firebase_auth_user: auth.ImportUserRecord, @@ -480,8 +550,12 @@ async def 
test_firebase_auth_delete_request_phone_number_identity( firebase_auth_secrets, ) -> None: """Delete request based on the Firebase Auth SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + clear_cache_identities(privacy_request.id) + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) - privacy_request = PrivacyRequest(id=f"test_firebase_delete_request_task_{uuid4()}") identity = Identity(**{"phone_number": firebase_auth_user.phone_number}) privacy_request.cache_identity(identity) @@ -489,9 +563,9 @@ async def test_firebase_auth_delete_request_phone_number_identity( merged_graph = firebase_auth_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [firebase_auth_connection_config], {"phone_number": firebase_auth_user.phone_number}, @@ -513,7 +587,7 @@ async def test_firebase_auth_delete_request_phone_number_identity( masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, @@ -542,9 +616,15 @@ async def test_firebase_auth_delete_request_phone_number_identity( @pytest.mark.integration_saas @pytest.mark.integration_saas_override @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_firebase_auth_user_delete_function( db, - policy, + dsr_version, + request, + privacy_request, firebase_auth_connection_config, firebase_auth_dataset_config, firebase_auth_user: auth.UserRecord, @@ -552,7 +632,11 @@ async def test_firebase_auth_user_delete_function( firebase_auth_secrets, ) -> None: """Tests delete functionality by explicitly invoking the delete override function""" - privacy_request = PrivacyRequest(id=f"test_firebase_delete_request_task_{uuid4()}") + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) + identity = Identity(**{"email": firebase_auth_user.email}) privacy_request.cache_identity(identity) @@ -580,9 +664,15 @@ async def test_firebase_auth_user_delete_function( @pytest.mark.integration_saas @pytest.mark.integration_saas_override @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_firebase_auth_user_delete_function_with_phone_number_identity( db, - policy, + dsr_version, + request, + privacy_request, firebase_auth_connection_config, firebase_auth_dataset_config, firebase_auth_user: auth.UserRecord, @@ -590,7 +680,12 @@ async def test_firebase_auth_user_delete_function_with_phone_number_identity( firebase_auth_secrets, ) -> None: """Tests delete functionality by explicitly invoking the delete override function""" - privacy_request = PrivacyRequest(id=f"test_firebase_delete_request_task_{uuid4()}") + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + clear_cache_identities(privacy_request.id) + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) + identity = Identity(**{"phone_number": firebase_auth_user.phone_number}) privacy_request.cache_identity(identity) diff --git a/tests/ops/integration_tests/saas/request_override/test_mailchimp_override_task.py 
b/tests/ops/integration_tests/saas/request_override/test_mailchimp_override_task.py index 9a99e25208..087dc48710 100644 --- a/tests/ops/integration_tests/saas/request_override/test_mailchimp_override_task.py +++ b/tests/ops/integration_tests/saas/request_override/test_mailchimp_override_task.py @@ -1,12 +1,9 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match """ @@ -30,18 +27,23 @@ @pytest.mark.integration_saas @pytest.mark.integration_saas_override @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_mailchimp_override_access_request_task( db, + privacy_request, + dsr_version, + request, policy, mailchimp_override_connection_config, mailchimp_override_dataset_config, mailchimp_identity_email, ) -> None: """Full access request based on the Mailchimp SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_mailchimp_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": mailchimp_identity_email}) privacy_request.cache_identity(identity) @@ -49,7 +51,7 @@ async def test_mailchimp_override_access_request_task( merged_graph = mailchimp_override_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -108,9 +110,15 @@ async def test_mailchimp_override_access_request_task( @pytest.mark.integration_saas @pytest.mark.integration_saas_override @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_mailchimp_erasure_request_task( db, - policy, + dsr_version, + request, + privacy_request, erasure_policy_string_rewrite, mailchimp_override_connection_config, mailchimp_override_dataset_config, @@ -118,10 +126,11 @@ async def test_mailchimp_erasure_request_task( reset_override_mailchimp_data, ) -> None: """Full erasure request based on the Mailchimp SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) - privacy_request = PrivacyRequest( - id=f"test_mailchimp_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": mailchimp_identity_email}) privacy_request.cache_identity(identity) @@ -129,16 +138,16 @@ async def test_mailchimp_erasure_request_task( merged_graph = mailchimp_override_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - await graph_task.run_access_request( + access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [mailchimp_override_connection_config], {"email": mailchimp_identity_email}, db, ) - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_ada_chatbot_task.py b/tests/ops/integration_tests/saas/test_ada_chatbot_task.py index 7d04c1d63c..69ea20c0a7 100644 --- a/tests/ops/integration_tests/saas/test_ada_chatbot_task.py +++ 
b/tests/ops/integration_tests/saas/test_ada_chatbot_task.py @@ -9,13 +9,21 @@ class TestAdaChatbotConnector: def test_connection(self, ada_chatbot_runner: ConnectorRunner): ada_chatbot_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, ada_chatbot_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, ada_chatbot_erasure_identity_email: str, + dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( _, erasure_results, diff --git a/tests/ops/integration_tests/saas/test_adobe_campaign_task.py b/tests/ops/integration_tests/saas/test_adobe_campaign_task.py index 2f37797797..36bf3b64e5 100644 --- a/tests/ops/integration_tests/saas/test_adobe_campaign_task.py +++ b/tests/ops/integration_tests/saas/test_adobe_campaign_task.py @@ -1,14 +1,11 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match @@ -21,18 +18,23 @@ def test_adobe_campaign_connection_test(adobe_campaign_connection_config) -> Non @pytest.mark.skip(reason="Only staging credentials available") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_adobe_campaign_access_request_task( policy, adobe_campaign_identity_email, adobe_campaign_connection_config, adobe_campaign_dataset_config, + dsr_version, + request, + privacy_request, db, ) -> None: """Full access request based on the Adobe Campaign SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_adobe_campaign_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": adobe_campaign_identity_email}) privacy_request.cache_identity(identity) @@ -40,7 +42,7 @@ async def test_adobe_campaign_access_request_task( merged_graph = adobe_campaign_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -159,34 +161,39 @@ async def test_adobe_campaign_access_request_task( @pytest.mark.skip(reason="Only staging credentials available") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_adobe_campaign_erasure_request_task( db, - policy, + erasure_policy, + privacy_request_with_erasure_policy, adobe_campaign_connection_config, adobe_campaign_dataset_config, adobe_campaign_erasure_identity_email, adobe_campaign_erasure_data, + dsr_version, + request, ) -> None: """Full erasure request based on the Adobe Campaign SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False # Allow GDPR Delete # Create user for GDPR delete erasure_email = adobe_campaign_erasure_identity_email - privacy_request = PrivacyRequest( - 
id=f"test_adobe_campaign_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": erasure_email}) - privacy_request.cache_identity(identity) + privacy_request_with_erasure_policy.cache_identity(identity) dataset_name = adobe_campaign_connection_config.get_saas_config().fides_key merged_graph = adobe_campaign_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( - privacy_request, - policy, + v = access_runner_tester( + privacy_request_with_erasure_policy, + erasure_policy, graph, [adobe_campaign_connection_config], {"email": erasure_email}, @@ -293,13 +300,13 @@ async def test_adobe_campaign_erasure_request_task( ], ) - x = await graph_task.run_erasure( - privacy_request, - policy, + x = erasure_runner_tester( + privacy_request_with_erasure_policy, + erasure_policy, graph, [adobe_campaign_connection_config], {"email": erasure_email}, - get_cached_data_for_erasures(privacy_request.id), + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) diff --git a/tests/ops/integration_tests/saas/test_adobe_sign_task.py b/tests/ops/integration_tests/saas/test_adobe_sign_task.py index 12b53ea3b6..290b49cb98 100644 --- a/tests/ops/integration_tests/saas/test_adobe_sign_task.py +++ b/tests/ops/integration_tests/saas/test_adobe_sign_task.py @@ -9,9 +9,20 @@ class TestAdobeSignConnector: def test_connection(self, adobe_sign_runner: ConnectorRunner): adobe_sign_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request( - self, adobe_sign_runner: ConnectorRunner, policy, adobe_sign_identity_email: str + self, + dsr_version, + request, + adobe_sign_runner: ConnectorRunner, + policy, + adobe_sign_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + access_results = await adobe_sign_runner.access_request( access_policy=policy, identities={"email": adobe_sign_identity_email} ) diff --git a/tests/ops/integration_tests/saas/test_adyen_task.py b/tests/ops/integration_tests/saas/test_adyen_task.py index dd56be9e7d..20b9f249c4 100644 --- a/tests/ops/integration_tests/saas/test_adyen_task.py +++ b/tests/ops/integration_tests/saas/test_adyen_task.py @@ -9,13 +9,20 @@ class TestAdyenConnector: def test_connection(self, adyen_runner: ConnectorRunner): adyen_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, adyen_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, adyen_erasure_identity_email: str, + dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 ( _, erasure_results, @@ -24,6 +31,7 @@ async def test_non_strict_erasure_request( erasure_policy=erasure_policy_string_rewrite, identities={"email": adyen_erasure_identity_email}, ) + assert erasure_results == { "adyen_external_dataset:adyen_external_collection": 0, "adyen_instance:user": 1, diff --git a/tests/ops/integration_tests/saas/test_aircall_task.py b/tests/ops/integration_tests/saas/test_aircall_task.py index 441a012307..4f796cdbab 100644 --- a/tests/ops/integration_tests/saas/test_aircall_task.py +++ b/tests/ops/integration_tests/saas/test_aircall_task.py @@ -1,14 +1,11 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import 
Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import get_config +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match from tests.ops.test_helpers.saas_test_utils import poll_for_existence @@ -22,18 +19,23 @@ def test_aircall_connection_test(aircall_connection_config) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_aircall_access_request_task_with_phone_number( db, policy, + dsr_version, + request, + privacy_request, aircall_connection_config, aircall_dataset_config, aircall_identity_phone_number, ) -> None: """Full access request based on the Aircall SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_aircall_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"phone_number": aircall_identity_phone_number}) privacy_request.cache_identity(identity) @@ -41,7 +43,7 @@ async def test_aircall_access_request_task_with_phone_number( merged_graph = aircall_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -78,9 +80,15 @@ async def test_aircall_access_request_task_with_phone_number( @pytest.mark.skip(reason="Temporarily disabled test") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_aircall_erasure_request_task( db, - policy, + privacy_request, + dsr_version, + request, erasure_policy_string_rewrite, aircall_connection_config, aircall_dataset_config, @@ -90,13 +98,14 @@ async def test_aircall_erasure_request_task( aircall_test_client, ) -> None: """Full erasure request based on the Aircall SaaS config""" + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) + + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False # Allow Delete - privacy_request = PrivacyRequest( - id=f"test_aircall_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"phone_number": aircall_erasure_identity_phone_number}) privacy_request.cache_identity(identity) @@ -104,9 +113,9 @@ async def test_aircall_erasure_request_task( merged_graph = aircall_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [aircall_connection_config], {"phone_number": aircall_erasure_identity_phone_number}, @@ -131,7 +140,7 @@ async def test_aircall_erasure_request_task( ], ) - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_amplitude_task.py b/tests/ops/integration_tests/saas/test_amplitude_task.py index fc9ed1e8e9..4f376e6851 100644 --- a/tests/ops/integration_tests/saas/test_amplitude_task.py +++ b/tests/ops/integration_tests/saas/test_amplitude_task.py @@ -1,14 +1,11 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from 
fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import get_config +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match CONFIG = get_config() @@ -21,18 +18,23 @@ def test_amplitude_connection_test(amplitude_connection_config) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_amplitude_access_request_task( db, policy, + dsr_version, + request, + privacy_request, amplitude_connection_config, amplitude_dataset_config, amplitude_identity_email, ) -> None: """Full access request based on the Amplitude SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_amplitude_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": amplitude_identity_email}) privacy_request.cache_identity(identity) @@ -40,7 +42,7 @@ async def test_amplitude_access_request_task( merged_graph = amplitude_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -174,9 +176,15 @@ async def test_amplitude_access_request_task( @pytest.mark.skip(reason="Temporarily disabled test") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_amplitude_erasure_request_task( db, - policy, + dsr_version, + request, + privacy_request, erasure_policy_string_rewrite, amplitude_connection_config, amplitude_dataset_config, @@ -184,13 +192,14 @@ async def test_amplitude_erasure_request_task( amplitude_create_erasure_data, ) -> None: """Full erasure request based on the Amplitude SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False # Allow Delete - privacy_request = PrivacyRequest( - id=f"test_amplitude_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": amplitude_erasure_identity_email}) privacy_request.cache_identity(identity) @@ -198,9 +207,9 @@ async def test_amplitude_erasure_request_task( merged_graph = amplitude_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [amplitude_connection_config], {"email": amplitude_erasure_identity_email}, @@ -256,7 +265,7 @@ async def test_amplitude_erasure_request_task( ], ) - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_appsflyer_task.py b/tests/ops/integration_tests/saas/test_appsflyer_task.py index e5db12c0db..af906d90a6 100644 --- a/tests/ops/integration_tests/saas/test_appsflyer_task.py +++ b/tests/ops/integration_tests/saas/test_appsflyer_task.py @@ -9,17 +9,34 @@ class TestAppsFlyerConnector: def test_connection(self, appsflyer_runner: ConnectorRunner): 
appsflyer_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request( - self, appsflyer_runner: ConnectorRunner, policy, appsflyer_identity_email: str + self, + appsflyer_runner: ConnectorRunner, + policy, + dsr_version, + request, + appsflyer_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + access_results = await appsflyer_runner.access_request( access_policy=policy, identities={"email": appsflyer_identity_email} ) assert len(access_results["appsflyer_instance:user"]) == 10 + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, + dsr_version, + request, appsflyer_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, diff --git a/tests/ops/integration_tests/saas/test_auth0_task.py b/tests/ops/integration_tests/saas/test_auth0_task.py index 15b6af0954..68d7193568 100644 --- a/tests/ops/integration_tests/saas/test_auth0_task.py +++ b/tests/ops/integration_tests/saas/test_auth0_task.py @@ -1,14 +1,11 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.fixtures.saas.auth0_fixtures import _user_exists from tests.ops.graph.graph_test_util import assert_rows_match from tests.ops.test_helpers.saas_test_utils import poll_for_existence @@ -20,6 +17,10 @@ def test_auth0_connection_test(auth0_connection_config) -> None: @pytest.mark.integration_saas +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_auth0_access_request_task( db, policy, @@ -27,12 +28,13 @@ async def test_auth0_access_request_task( auth0_dataset_config, auth0_identity_email, auth0_access_data, + privacy_request, + dsr_version, + request, ) -> None: """Full access request based on the Auth0 SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_auth0_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": auth0_identity_email}) privacy_request.cache_identity(identity) @@ -40,7 +42,7 @@ async def test_auth0_access_request_task( merged_graph = auth0_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -88,23 +90,30 @@ async def test_auth0_access_request_task( @pytest.mark.integration_saas +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_auth0_erasure_request_task( db, - policy, erasure_policy_string_rewrite, auth0_connection_config, auth0_dataset_config, auth0_erasure_identity_email, auth0_erasure_data, auth0_token, + privacy_request_with_erasure_policy, + dsr_version, + request, ) -> None: """Full erasure request based on the Auth0 SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request_with_erasure_policy.policy_id = erasure_policy_string_rewrite.id + privacy_request_with_erasure_policy.save(db) - privacy_request = 
PrivacyRequest( - id=f"test_auth0_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": auth0_erasure_identity_email}) - privacy_request.cache_identity(identity) + privacy_request_with_erasure_policy.cache_identity(identity) dataset_name = auth0_connection_config.get_saas_config().fides_key merged_graph = auth0_dataset_config.get_graph() @@ -113,9 +122,9 @@ async def test_auth0_erasure_request_task( temp_masking = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False - v = await graph_task.run_access_request( - privacy_request, - policy, + v = access_runner_tester( + privacy_request_with_erasure_policy, + erasure_policy_string_rewrite, graph, [auth0_connection_config], {"email": auth0_erasure_identity_email}, @@ -140,13 +149,13 @@ async def test_auth0_erasure_request_task( ], ) - x = await graph_task.run_erasure( - privacy_request, + x = erasure_runner_tester( + privacy_request_with_erasure_policy, erasure_policy_string_rewrite, graph, [auth0_connection_config], {"email": auth0_erasure_identity_email}, - get_cached_data_for_erasures(privacy_request.id), + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) assert x == { diff --git a/tests/ops/integration_tests/saas/test_braintree_task.py b/tests/ops/integration_tests/saas/test_braintree_task.py index cbd9d826fe..6cf85dbe74 100644 --- a/tests/ops/integration_tests/saas/test_braintree_task.py +++ b/tests/ops/integration_tests/saas/test_braintree_task.py @@ -1,15 +1,13 @@ import logging -import random import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match logger = logging.getLogger(__name__) @@ -24,20 +22,26 @@ def test_braintree_connection_test( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_braintree_access_request_task( db, policy, + dsr_version, + request, braintree_connection_config, braintree_dataset_config, braintree_identity_email, connection_config, + privacy_request, braintree_postgres_dataset_config, braintree_postgres_db, ) -> None: """Full access request based on the Braintree Conversations SaaS config""" - privacy_request = PrivacyRequest( - id=f"test_braintree_access_request_task_{random.randint(0, 1000)}" - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + identity_attribute = "email" identity_value = braintree_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -48,7 +52,7 @@ async def test_braintree_access_request_task( merged_graph = braintree_dataset_config.get_graph() graph = DatasetGraph(*[merged_graph, braintree_postgres_dataset_config.get_graph()]) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -81,9 +85,15 @@ async def test_braintree_access_request_task( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_braintree_erasure_request_task( db, - policy, + dsr_version, + request, + privacy_request, 
braintree_connection_config, braintree_dataset_config, connection_config, @@ -94,10 +104,11 @@ async def test_braintree_erasure_request_task( braintree_postgres_erasure_db, ) -> None: """Full erasure request based on the Braintree SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) - privacy_request = PrivacyRequest( - id=f"test_braintree_access_request_task_{random.randint(0, 1000)}" - ) identity_attribute = "email" identity_value = braintree_erasure_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -108,9 +119,9 @@ async def test_braintree_erasure_request_task( merged_graph = braintree_dataset_config.get_graph() graph = DatasetGraph(*[merged_graph, braintree_postgres_dataset_config.get_graph()]) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [braintree_connection_config, connection_config], {"email": braintree_erasure_identity_email}, @@ -126,7 +137,7 @@ async def test_braintree_erasure_request_task( temp_masking = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = True - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_braze_task.py b/tests/ops/integration_tests/saas/test_braze_task.py index c437b44a12..bf05a5798e 100644 --- a/tests/ops/integration_tests/saas/test_braze_task.py +++ b/tests/ops/integration_tests/saas/test_braze_task.py @@ -1,16 +1,14 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.fixtures.saas.braze_fixtures import _user_exists from tests.ops.graph.graph_test_util import assert_rows_match +from tests.ops.test_helpers.cache_secrets_helper import clear_cache_identities from tests.ops.test_helpers.saas_test_utils import poll_for_existence @@ -21,18 +19,23 @@ def test_braze_connection_test(braze_connection_config) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_braze_access_request_task_with_email( db, policy, + dsr_version, + request, + privacy_request, braze_connection_config, braze_dataset_config, braze_identity_email, ) -> None: """Full access request based on the Braze SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_braze_access_request_task_{random.randint(0, 250)}" - ) identity_attribute = "email" identity_value = braze_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -43,7 +46,7 @@ async def test_braze_access_request_task_with_email( merged_graph = braze_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -95,19 +98,24 @@ async def test_braze_access_request_task_with_email( @pytest.mark.integration_saas @pytest.mark.asyncio 
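# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of this diff): every test in
# this patch is parametrized over "use_dsr_3_0" / "use_dsr_2_0" and then calls
# request.getfixturevalue(dsr_version) so the same test body runs once per DSR
# version. Those strings are assumed to name pytest fixtures in
# tests/conftest.py that flip an execution-level use_dsr_3_0 setting for the
# duration of a single test; the fixture bodies below are an assumption about
# how that could look, not the PR's actual implementation.
import pytest
from fides.config import CONFIG

@pytest.fixture
def use_dsr_3_0():
    """Hypothetical fixture: route task execution through the DSR 3.0 scheduler."""
    original = CONFIG.execution.use_dsr_3_0
    CONFIG.execution.use_dsr_3_0 = True
    yield
    CONFIG.execution.use_dsr_3_0 = original  # restore so later tests see the default

@pytest.fixture
def use_dsr_2_0():
    """Hypothetical fixture: force the legacy DSR 2.0 execution path."""
    original = CONFIG.execution.use_dsr_3_0
    CONFIG.execution.use_dsr_3_0 = False
    yield
    CONFIG.execution.use_dsr_3_0 = original
# ---------------------------------------------------------------------------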
+@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_braze_access_request_task_with_phone_number( db, policy, + dsr_version, + request, + privacy_request, braze_connection_config, braze_dataset_config, - braze_identity_email, braze_identity_phone_number, ) -> None: """Full access request based on the Braze SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + clear_cache_identities(privacy_request.id) - privacy_request = PrivacyRequest( - id=f"test_braze_access_request_task_{random.randint(0, 1000)}" - ) identity_kwargs = {"phone_number": braze_identity_phone_number} identity = Identity(**identity_kwargs) privacy_request.cache_identity(identity) @@ -116,7 +124,7 @@ async def test_braze_access_request_task_with_phone_number( merged_graph = braze_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -152,9 +160,15 @@ async def test_braze_access_request_task_with_phone_number( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_braze_erasure_request_task( db, - policy, + dsr_version, + request, + privacy_request, erasure_policy_string_rewrite_name_and_email, braze_connection_config, braze_dataset_config, @@ -162,10 +176,11 @@ async def test_braze_erasure_request_task( braze_erasure_data, ) -> None: """Full erasure request based on the Braze SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite_name_and_email.id + privacy_request.save(db) - privacy_request = PrivacyRequest( - id=f"test_braze_erasure_request_task_{random.randint(0, 1000)}" - ) identity_attribute = "email" identity_value = braze_erasure_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -177,9 +192,9 @@ async def test_braze_erasure_request_task( merged_graph = braze_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite_name_and_email, graph, [braze_connection_config], identity_kwargs, @@ -207,7 +222,7 @@ async def test_braze_erasure_request_task( temp_masking = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = True - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite_name_and_email, graph, diff --git a/tests/ops/integration_tests/saas/test_datadog_task.py b/tests/ops/integration_tests/saas/test_datadog_task.py index b1c7125722..06027d49b5 100644 --- a/tests/ops/integration_tests/saas/test_datadog_task.py +++ b/tests/ops/integration_tests/saas/test_datadog_task.py @@ -1,13 +1,11 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task +from tests.conftest import access_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match +from tests.ops.test_helpers.cache_secrets_helper import clear_cache_identities @pytest.mark.integration_saas @@ -17,19 +15,24 @@ def test_datadog_connection_test(datadog_connection_config) -> None: @pytest.mark.integration_saas 
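# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of this diff): the
# access_runner_tester / erasure_runner_tester helpers imported from
# tests.conftest replace the direct `await graph_task.run_access_request(...)`
# and `await graph_task.run_erasure(...)` calls seen in the removed lines,
# taking the same arguments but dispatching to whichever engine the active
# dsr_version fixture selected. A hedged sketch of the access-side helper;
# the DSR 3.0 branch is deliberately left abstract because its real entry
# point lives in this PR's conftest and is not reproduced here.
import asyncio
from fides.api.task import graph_task
from fides.config import CONFIG

def access_runner_tester(privacy_request, policy, graph, connection_configs, identities, db):
    """Hypothetical wrapper: run the access step under the configured DSR version."""
    if CONFIG.execution.use_dsr_3_0:
        # DSR 3.0: hand the request off to the new parallel scheduler
        # (illustrative only - see tests/conftest.py for the real dispatch).
        raise NotImplementedError("illustrative only - see tests/conftest.py")
    # DSR 2.0: drive the legacy coroutine the tests previously awaited directly.
    return asyncio.run(
        graph_task.run_access_request(
            privacy_request, policy, graph, connection_configs, identities, db
        )
    )
# ---------------------------------------------------------------------------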
@pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_datadog_access_request_task_with_email( db, policy, + dsr_version, + request, + privacy_request, datadog_connection_config, datadog_dataset_config, datadog_identity_email, datadog_access_data, ) -> None: """Full access request based on the Datadog SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_datadog_access_request_task_{random.randint(0, 1000)}" - ) identity_attribute = "email" identity_value = datadog_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -39,7 +42,7 @@ async def test_datadog_access_request_task_with_email( dataset_name = datadog_connection_config.get_saas_config().fides_key merged_graph = datadog_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -77,20 +80,25 @@ async def test_datadog_access_request_task_with_email( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_datadog_access_request_task_with_phone_number( db, + dsr_version, + request, policy, + privacy_request, datadog_connection_config, datadog_dataset_config, - datadog_identity_email, datadog_identity_phone_number, datadog_access_data, ) -> None: """Full access request based on the Datadog SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + clear_cache_identities(privacy_request.id) - privacy_request = PrivacyRequest( - id=f"test_datadog_access_request_task_{random.randint(0, 1000)}" - ) identity_attribute = "phone_number" identity_value = datadog_identity_phone_number identity_kwargs = {identity_attribute: identity_value} @@ -100,7 +108,7 @@ async def test_datadog_access_request_task_with_phone_number( dataset_name = datadog_connection_config.get_saas_config().fides_key merged_graph = datadog_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, diff --git a/tests/ops/integration_tests/saas/test_delighted_task.py b/tests/ops/integration_tests/saas/test_delighted_task.py index 523f940804..5332bea1f5 100644 --- a/tests/ops/integration_tests/saas/test_delighted_task.py +++ b/tests/ops/integration_tests/saas/test_delighted_task.py @@ -1,14 +1,11 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import get_config +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match from tests.ops.test_helpers.saas_test_utils import poll_for_existence @@ -22,18 +19,23 @@ def test_delighted_connection_test(delighted_connection_config) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_delighted_access_request_task( db, + dsr_version, + request, policy, + privacy_request, delighted_connection_config, delighted_dataset_config, 
delighted_identity_email, ) -> None: """Full access request based on the Delighted SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_delighted_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": delighted_identity_email}) privacy_request.cache_identity(identity) @@ -41,7 +43,7 @@ async def test_delighted_access_request_task( merged_graph = delighted_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -92,9 +94,16 @@ async def test_delighted_access_request_task( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_delighted_erasure_request_task( db, policy, + dsr_version, + request, + privacy_request, erasure_policy_string_rewrite, delighted_connection_config, delighted_dataset_config, @@ -103,15 +112,16 @@ async def test_delighted_erasure_request_task( delighted_test_client, ) -> None: """Full erasure request based on the Delighted SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) person = delighted_create_erasure_data masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False # Allow Delete - privacy_request = PrivacyRequest( - id=f"test_delighted_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": delighted_erasure_identity_email}) privacy_request.cache_identity(identity) @@ -119,9 +129,9 @@ async def test_delighted_erasure_request_task( merged_graph = delighted_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [delighted_connection_config], {"email": delighted_erasure_identity_email}, @@ -161,7 +171,7 @@ async def test_delighted_erasure_request_task( ], ) - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_domo_task.py b/tests/ops/integration_tests/saas/test_domo_task.py index 0ffcf8e385..48ddf36d3f 100644 --- a/tests/ops/integration_tests/saas/test_domo_task.py +++ b/tests/ops/integration_tests/saas/test_domo_task.py @@ -1,14 +1,11 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match @@ -21,18 +18,23 @@ def test_domo_connection_test(domo_connection_config) -> None: @pytest.mark.skip(reason="Pending account resolution") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_domo_access_request_task( policy, + dsr_version, + request, + privacy_request, domo_identity_email, domo_connection_config, domo_dataset_config, db, ) -> None: """Full access request 
based on the Domo SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_domo_access_request_task_{random.randint(0, 1000)}" - ) identity_kwargs = {"email": domo_identity_email} identity = Identity(**identity_kwargs) privacy_request.cache_identity(identity) @@ -41,7 +43,7 @@ async def test_domo_access_request_task( merged_graph = domo_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -73,9 +75,15 @@ async def test_domo_access_request_task( @pytest.mark.skip(reason="Pending account resolution") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_domo_erasure_request_task( db, - policy, + dsr_version, + request, + privacy_request, erasure_policy_string_rewrite_name_and_email, domo_erasure_identity_email, domo_create_erasure_data, @@ -84,10 +92,12 @@ async def test_domo_erasure_request_task( domo_dataset_config, ) -> None: """Full access request based on the Domo SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + user_id = domo_create_erasure_data - privacy_request = PrivacyRequest( - id=f"test_domo_erasure_request_task_{random.randint(0, 1000)}" - ) + privacy_request.policy_id = erasure_policy_string_rewrite_name_and_email.id + privacy_request.save(db) + identity_kwargs = {"email": domo_erasure_identity_email} identity = Identity(**identity_kwargs) privacy_request.cache_identity(identity) @@ -96,9 +106,9 @@ async def test_domo_erasure_request_task( merged_graph = domo_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite_name_and_email, graph, [domo_connection_config], identity_kwargs, @@ -128,7 +138,7 @@ async def test_domo_erasure_request_task( masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = True - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite_name_and_email, graph, diff --git a/tests/ops/integration_tests/saas/test_doordash_task.py b/tests/ops/integration_tests/saas/test_doordash_task.py index 754f9adf7a..39ce94fe3c 100644 --- a/tests/ops/integration_tests/saas/test_doordash_task.py +++ b/tests/ops/integration_tests/saas/test_doordash_task.py @@ -9,12 +9,20 @@ class TestDoordashConnector: def test_connection(self, doordash_runner: ConnectorRunner): doordash_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request( self, + dsr_version, + request, doordash_runner: ConnectorRunner, policy: Policy, doordash_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + await doordash_runner.access_request( access_policy=policy, identities={"email": doordash_identity_email} ) diff --git a/tests/ops/integration_tests/saas/test_friendbuy_nextgen_task.py b/tests/ops/integration_tests/saas/test_friendbuy_nextgen_task.py index bad9969e11..48079ddb42 100644 --- a/tests/ops/integration_tests/saas/test_friendbuy_nextgen_task.py +++ b/tests/ops/integration_tests/saas/test_friendbuy_nextgen_task.py @@ -1,16 +1,14 @@ import logging -import random from time import sleep import pytest from fides.api.graph.graph 
import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match logger = logging.getLogger(__name__) @@ -25,18 +23,24 @@ def test_friendbuy_nextgen_connection_test( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_friendbuy_nextgen_access_request_task( db, + dsr_version, + request, policy, + privacy_request, friendbuy_nextgen_connection_config, friendbuy_nextgen_dataset_config, friendbuy_nextgen_identity_email, connection_config, ) -> None: """Full access request based on the Friendbuy Nextgen Conversations SaaS config""" - privacy_request = PrivacyRequest( - id=f"test_friendbuy_nextgen_access_request_task_{random.randint(0, 1000)}" - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + identity_attribute = "email" identity_value = friendbuy_nextgen_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -47,7 +51,7 @@ async def test_friendbuy_nextgen_access_request_task( merged_graph = friendbuy_nextgen_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -79,9 +83,15 @@ async def test_friendbuy_nextgen_access_request_task( @pytest.mark.skip(reason="Temporarily disabled test") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_friendbuy_nextgen_erasure_request_task( db, - policy, + dsr_version, + request, + privacy_request, friendbuy_nextgen_connection_config, friendbuy_nextgen_dataset_config, connection_config, @@ -90,10 +100,11 @@ async def test_friendbuy_nextgen_erasure_request_task( friendbuy_nextgen_erasure_data, ) -> None: """Full erasure request based on the Friendbuy Nextgen SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) - privacy_request = PrivacyRequest( - id=f"test_friendbuy_nextgen_access_request_task_{random.randint(0, 1000)}" - ) identity_attribute = "email" identity_value = friendbuy_nextgen_erasure_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -107,9 +118,9 @@ async def test_friendbuy_nextgen_erasure_request_task( # Adding 30 seconds sleep because sometimes Friendbuy Nextgen system takes around 30 seconds for user to be available sleep(30) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [friendbuy_nextgen_connection_config, connection_config], {"email": friendbuy_nextgen_erasure_identity_email}, @@ -138,7 +149,7 @@ async def test_friendbuy_nextgen_erasure_request_task( temp_masking = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_friendbuy_task.py 
b/tests/ops/integration_tests/saas/test_friendbuy_task.py index 0291312469..4d18fd2c21 100644 --- a/tests/ops/integration_tests/saas/test_friendbuy_task.py +++ b/tests/ops/integration_tests/saas/test_friendbuy_task.py @@ -1,17 +1,15 @@ import logging -import random from time import sleep import pytest import requests from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match logger = logging.getLogger(__name__) @@ -28,9 +26,16 @@ def test_friendbuy_connection_test( @pytest.mark.skip(reason="No active account") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_friendbuy_access_request_task( db, + dsr_version, + request, policy, + privacy_request, friendbuy_connection_config, friendbuy_dataset_config, friendbuy_identity_email, @@ -39,9 +44,8 @@ async def test_friendbuy_access_request_task( friendbuy_postgres_db, ) -> None: """Full access request based on the Friendbuy Conversations SaaS config""" - privacy_request = PrivacyRequest( - id=f"test_friendbuy_access_request_task_{random.randint(0, 1000)}" - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + identity_attribute = "email" identity_value = friendbuy_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -52,7 +56,7 @@ async def test_friendbuy_access_request_task( merged_graph = friendbuy_dataset_config.get_graph() graph = DatasetGraph(*[merged_graph, friendbuy_postgres_dataset_config.get_graph()]) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -83,9 +87,15 @@ async def test_friendbuy_access_request_task( @pytest.mark.skip(reason="No active account") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_friendbuy_erasure_request_task( db, - policy, + dsr_version, + request, + privacy_request, friendbuy_connection_config, friendbuy_dataset_config, connection_config, @@ -96,10 +106,11 @@ async def test_friendbuy_erasure_request_task( friendbuy_postgres_erasure_db, ) -> None: """Full erasure request based on the Friendbuy SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) - privacy_request = PrivacyRequest( - id=f"test_friendbuy_access_request_task_{random.randint(0, 1000)}" - ) identity_attribute = "email" identity_value = friendbuy_erasure_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -110,9 +121,9 @@ async def test_friendbuy_erasure_request_task( merged_graph = friendbuy_dataset_config.get_graph() graph = DatasetGraph(*[merged_graph, friendbuy_postgres_dataset_config.get_graph()]) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [friendbuy_connection_config, connection_config], {"email": friendbuy_erasure_identity_email}, @@ -140,7 +151,7 @@ async def 
test_friendbuy_erasure_request_task( temp_masking = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_fullstory_task.py b/tests/ops/integration_tests/saas/test_fullstory_task.py index f172937198..a082e268e4 100644 --- a/tests/ops/integration_tests/saas/test_fullstory_task.py +++ b/tests/ops/integration_tests/saas/test_fullstory_task.py @@ -1,14 +1,11 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.fixtures.saas.fullstory_fixtures import FullstoryTestClient, user_updated from tests.ops.graph.graph_test_util import assert_rows_match from tests.ops.test_helpers.saas_test_utils import poll_for_existence @@ -25,9 +22,16 @@ def test_fullstory_connection_test( @pytest.mark.skip(reason="API keys are temporary for free accounts") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_fullstory_access_request_task( db, + dsr_version, + request, policy, + privacy_request, fullstory_connection_config, fullstory_dataset_config, fullstory_identity_email, @@ -36,9 +40,8 @@ async def test_fullstory_access_request_task( fullstory_postgres_db, ) -> None: """Full access request based on the Fullstory SaaS config""" - privacy_request = PrivacyRequest( - id=f"test_fullstory_access_request_task_{random.randint(0, 1000)}" - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + identity_attribute = "email" identity_value = fullstory_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -50,7 +53,7 @@ async def test_fullstory_access_request_task( merged_graph = fullstory_dataset_config.get_graph() graph = DatasetGraph(*[merged_graph, fullstory_postgres_dataset_config.get_graph()]) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -77,12 +80,18 @@ async def test_fullstory_access_request_task( @pytest.mark.skip(reason="API keys are temporary for free accounts") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_fullstory_erasure_request_task( db, - policy, + dsr_version, + request, fullstory_connection_config, fullstory_dataset_config, connection_config, + privacy_request, fullstory_postgres_erasure_db, fullstory_postgres_dataset_config, erasure_policy_string_rewrite, @@ -92,10 +101,8 @@ async def test_fullstory_erasure_request_task( fullstory_test_client: FullstoryTestClient, ) -> None: """Full erasure request based on the Fullstory SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_fullstory_access_request_task_{random.randint(0, 1000)}" - ) identity_attribute = "email" identity_value = fullstory_erasure_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -106,9 +113,9 @@ async def 
test_fullstory_erasure_request_task( merged_graph = fullstory_dataset_config.get_graph() graph = DatasetGraph(*[merged_graph, fullstory_postgres_dataset_config.get_graph()]) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [fullstory_connection_config, connection_config], {"email": fullstory_erasure_identity_email}, @@ -132,7 +139,7 @@ async def test_fullstory_erasure_request_task( temp_masking = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = True - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_gong_task.py b/tests/ops/integration_tests/saas/test_gong_task.py index 160d74e5a4..8fea525638 100644 --- a/tests/ops/integration_tests/saas/test_gong_task.py +++ b/tests/ops/integration_tests/saas/test_gong_task.py @@ -9,13 +9,21 @@ class TestGongConnector: def test_connection(self, gong_runner: ConnectorRunner): gong_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request( self, + dsr_version, + request, gong_runner: ConnectorRunner, policy, gong_identity_email: str, gong_identity_name: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + access_results = await gong_runner.access_request( access_policy=policy, identities={"email": gong_identity_email} ) @@ -27,8 +35,14 @@ async def test_access_request( for obj in objects: assert obj["fields"][0] == {"name": "fullName", "value": gong_identity_name} + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, + dsr_version, + request, gong_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, diff --git a/tests/ops/integration_tests/saas/test_google_analytics_task.py b/tests/ops/integration_tests/saas/test_google_analytics_task.py index c3aa89b616..d6e13ad5bd 100644 --- a/tests/ops/integration_tests/saas/test_google_analytics_task.py +++ b/tests/ops/integration_tests/saas/test_google_analytics_task.py @@ -1,22 +1,16 @@ from unittest import mock -from uuid import uuid4 import pytest from fides.api.models.policy import ActionType -from fides.api.models.privacy_request import ( - ExecutionLog, - ExecutionLogStatus, - PrivacyRequest, - PrivacyRequestStatus, -) +from fides.api.models.privacy_request import ExecutionLog, ExecutionLogStatus from fides.api.schemas.redis_cache import Identity from fides.api.schemas.saas.shared_schemas import SaaSRequestParams from fides.api.service.connectors import get_connector from fides.api.service.privacy_request.request_runner_service import ( build_consent_dataset_graph, ) -from fides.api.task import graph_task +from tests.conftest import consent_runner_tester @pytest.mark.integration_saas @@ -30,26 +24,34 @@ def test_google_analytics_connection_test( @pytest.mark.integration_saas @pytest.mark.asyncio @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_google_analytics_consent_request_task_old_workflow( db, consent_policy, google_analytics_connection_config, google_analytics_dataset_config, google_analytics_client_id, + privacy_request, + dsr_version, + request, ) -> None: """Full consent request based on the Google Analytics SaaS config""" +
privacy_request.consent_preferences = [ + {"data_use": "marketing.advertising", "opt_in": False} + ] + privacy_request.save(db) - privacy_request = PrivacyRequest( - id=str(uuid4()), - consent_preferences=[{"data_use": "marketing.advertising", "opt_in": False}], - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 identity = Identity(**{"ga_client_id": google_analytics_client_id}) privacy_request.cache_identity(identity) dataset_name = "google_analytics_instance" - v = await graph_task.run_consent_request( + v = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([google_analytics_dataset_config]), @@ -88,6 +91,10 @@ async def test_google_analytics_consent_request_task_old_workflow( @pytest.mark.asyncio @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_google_analytics_consent_prepared_requests_old_workflow( mocked_client_send, db, @@ -95,18 +102,22 @@ async def test_google_analytics_consent_prepared_requests_old_workflow( google_analytics_connection_config, google_analytics_dataset_config, google_analytics_client_id, + privacy_request, + dsr_version, + request, ) -> None: """Assert attributes of the PreparedRequest created by the client for running the consent request""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=str(uuid4()), - consent_preferences=[{"data_use": "marketing.advertising", "opt_in": False}], - ) + privacy_request.consent_preferences = [ + {"data_use": "marketing.advertising", "opt_in": False} + ] + privacy_request.save(db) identity = Identity(**{"ga_client_id": google_analytics_client_id}) privacy_request.cache_identity(identity) - await graph_task.run_consent_request( + consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([google_analytics_dataset_config]), @@ -131,26 +142,35 @@ async def test_google_analytics_consent_prepared_requests_old_workflow( @pytest.mark.asyncio @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_google_analytics_no_ga_client_id_old_workflow( mocked_client_send, db, consent_policy, google_analytics_connection_config, google_analytics_dataset_config, + privacy_request, + dsr_version, + request, ) -> None: """Test that the google analytics connector does not fail if there is no ga_client_id We skip the request because it is marked as skip_missing_param_values=True. We won't always have this piece of identity data. 
""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.consent_preferences = [ + {"data_use": "marketing.advertising", "opt_in": False} + ] + privacy_request.save(db) - privacy_request = PrivacyRequest( - id=str(uuid4()), - consent_preferences=[{"data_use": "marketing.advertising", "opt_in": False}], - ) dataset_name = "google_analytics_instance" - v = await graph_task.run_consent_request( + v = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([google_analytics_dataset_config]), @@ -168,6 +188,10 @@ async def test_google_analytics_no_ga_client_id_old_workflow( @pytest.mark.integration_saas @pytest.mark.asyncio @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_google_analytics_no_ga_client_id_new_workflow( mocked_client_send, db, @@ -175,18 +199,20 @@ async def test_google_analytics_no_ga_client_id_new_workflow( google_analytics_connection_config_without_secrets, google_analytics_dataset_config_no_secrets, privacy_preference_history, + privacy_request, + dsr_version, + request, ) -> None: """Test google analytics connector skips instead of fails if identity missing.""" - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + privacy_request.save(db) privacy_preference_history.privacy_request_id = privacy_request.id privacy_preference_history.save(db=db) dataset_name = "google_analytics_instance" - v = await graph_task.run_consent_request( + v = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([google_analytics_dataset_config_no_secrets]), @@ -232,6 +258,10 @@ async def test_google_analytics_no_ga_client_id_new_workflow( @pytest.mark.integration_saas @pytest.mark.asyncio @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_google_analytics_consent_request_task_new_workflow( db, consent_policy, @@ -241,15 +271,17 @@ async def test_google_analytics_consent_request_task_new_workflow( privacy_preference_history, privacy_preference_history_us_ca_provide, system, + privacy_request, + dsr_version, + request, ) -> None: """Full consent request based on the Google Analytics SaaS config for the new workflow where we save preferences w.r.t. 
privacy notices""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + google_analytics_connection_config.system_id = system.id google_analytics_connection_config.save(db) - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) privacy_request.save(db) # This preference matches on data use privacy_preference_history.privacy_request_id = privacy_request.id @@ -264,7 +296,7 @@ async def test_google_analytics_consent_request_task_new_workflow( dataset_name = "google_analytics_instance" - v = await graph_task.run_consent_request( + v = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([google_analytics_dataset_config]), @@ -317,6 +349,10 @@ async def test_google_analytics_consent_request_task_new_workflow( @pytest.mark.asyncio @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_google_analytics_consent_request_task_new_errored_workflow( mocked_client_send, db, @@ -327,22 +363,22 @@ async def test_google_analytics_consent_request_task_new_errored_workflow( privacy_preference_history, privacy_preference_history_us_ca_provide, system, + privacy_request, + dsr_version, + request, ) -> None: """Testing errored Google Analytics SaaS config for the new workflow where we save preferences w.r.t. privacy notices Assert logging created appropriately """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + mocked_client_send.side_effect = Exception("KeyError") google_analytics_connection_config.system_id = system.id google_analytics_connection_config.save(db) - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) - privacy_request.save(db) - # This preference matches on data use privacy_preference_history.privacy_request_id = privacy_request.id privacy_preference_history.save(db=db) @@ -355,7 +391,7 @@ async def test_google_analytics_consent_request_task_new_errored_workflow( privacy_request.cache_identity(identity) with pytest.raises(Exception): - await graph_task.run_consent_request( + consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([google_analytics_dataset_config]), @@ -397,6 +433,10 @@ async def test_google_analytics_consent_request_task_new_errored_workflow( @pytest.mark.asyncio @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_google_analytics_consent_prepared_requests_new_workflow( mocked_client_send, db, @@ -405,13 +445,15 @@ async def test_google_analytics_consent_prepared_requests_new_workflow( google_analytics_dataset_config, google_analytics_client_id, privacy_preference_history, + privacy_request, + dsr_version, + request, ) -> None: """Assert attributes of the PreparedRequest created by the client for running the consent request for the new workflow """ - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + privacy_request.save(db) privacy_preference_history.privacy_request_id = privacy_request.id privacy_preference_history.save(db=db) @@ -419,7 +461,7 @@ async def 
test_google_analytics_consent_prepared_requests_new_workflow( identity = Identity(**{"ga_client_id": google_analytics_client_id}) privacy_request.cache_identity(identity) - await graph_task.run_consent_request( + consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([google_analytics_dataset_config]), diff --git a/tests/ops/integration_tests/saas/test_gorgias_task.py b/tests/ops/integration_tests/saas/test_gorgias_task.py index 2591b896cd..3ecab7cdda 100644 --- a/tests/ops/integration_tests/saas/test_gorgias_task.py +++ b/tests/ops/integration_tests/saas/test_gorgias_task.py @@ -1,15 +1,12 @@ -import random - import pytest import requests from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import get_config +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match CONFIG = get_config() @@ -22,18 +19,23 @@ def test_gorgias_connection_test(gorgias_connection_config) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_gorgias_access_request_task( db, policy, gorgias_connection_config, gorgias_dataset_config, gorgias_identity_email, + privacy_request, + dsr_version, + request, ) -> None: """Full access request based on the Gorgias SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_gorgias_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": gorgias_identity_email}) privacy_request.cache_identity(identity) @@ -41,7 +43,7 @@ async def test_gorgias_access_request_task( merged_graph = gorgias_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -157,23 +159,30 @@ async def test_gorgias_access_request_task( @pytest.mark.skip(reason="Pending account resolution") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_gorgias_erasure_request_task( db, - policy, erasure_policy_string_rewrite, gorgias_connection_config, gorgias_dataset_config, gorgias_erasure_identity_email, + privacy_request, + request, + dsr_version, gorgias_create_erasure_data, ) -> None: """Full erasure request based on the Gorgias SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = True - privacy_request = PrivacyRequest( - id=f"test_gorgias_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": gorgias_erasure_identity_email}) privacy_request.cache_identity(identity) @@ -181,9 +190,9 @@ async def test_gorgias_erasure_request_task( merged_graph = gorgias_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [gorgias_connection_config], 
{"email": gorgias_erasure_identity_email}, @@ -286,7 +295,7 @@ async def test_gorgias_erasure_request_task( ], ) - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_heap_task.py b/tests/ops/integration_tests/saas/test_heap_task.py index 5e47a67463..36d6abdd89 100644 --- a/tests/ops/integration_tests/saas/test_heap_task.py +++ b/tests/ops/integration_tests/saas/test_heap_task.py @@ -10,13 +10,21 @@ class TestHeapConnector: def test_connection(self, heap_runner: ConnectorRunner): heap_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, + dsr_version, + request, heap_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, heap_erasure_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( access_results, erasure_results, diff --git a/tests/ops/integration_tests/saas/test_hubspot_task.py b/tests/ops/integration_tests/saas/test_hubspot_task.py index 35f2a9dd3a..10fa72c346 100644 --- a/tests/ops/integration_tests/saas/test_hubspot_task.py +++ b/tests/ops/integration_tests/saas/test_hubspot_task.py @@ -1,15 +1,12 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.filter_results import filter_data_categories from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.fixtures.saas.hubspot_fixtures import HubspotTestClient, user_exists from tests.ops.graph.graph_test_util import assert_rows_match from tests.ops.test_helpers.saas_test_utils import poll_for_existence @@ -22,18 +19,23 @@ def test_hubspot_connection_test(connection_config_hubspot) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_hubspot_access_request_task( db, + dsr_version, + request, policy, connection_config_hubspot, dataset_config_hubspot, hubspot_identity_email, + privacy_request, ) -> None: """Full access request based on the Hubspot SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_hubspot_access_request_task_{random.randint(0, 1000)}" - ) identity_attribute = "email" identity_value = hubspot_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -44,7 +46,7 @@ async def test_hubspot_access_request_task( merged_graph = dataset_config_hubspot.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -124,9 +126,12 @@ async def test_hubspot_access_request_task( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.usefixtures( + "use_dsr_3_0" +) # Only testing on DSR 3.0 not 2.0 - because of fixtures taking too long to settle down async def test_hubspot_erasure_request_task( db, - policy, + privacy_request, erasure_policy_string_rewrite_name_and_email, connection_config_hubspot, dataset_config_hubspot, @@ -135,10 +140,11 @@ async def 
test_hubspot_erasure_request_task( hubspot_test_client: HubspotTestClient, ) -> None: """Full erasure request based on the Hubspot SaaS config""" + + privacy_request.policy_id = erasure_policy_string_rewrite_name_and_email.id + privacy_request.save(db) contact_id, user_id = hubspot_erasure_data - privacy_request = PrivacyRequest( - id=f"test_hubspot_erasure_request_task_{random.randint(0, 1000)}" - ) + identity_attribute = "email" identity_kwargs = {identity_attribute: (hubspot_erasure_identity_email)} identity = Identity(**identity_kwargs) @@ -148,8 +154,13 @@ async def test_hubspot_erasure_request_task( merged_graph = dataset_config_hubspot.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( - privacy_request, policy, graph, [connection_config_hubspot], identity_kwargs, db + v = access_runner_tester( + privacy_request, + erasure_policy_string_rewrite_name_and_email, + graph, + [connection_config_hubspot], + identity_kwargs, + db, ) assert_rows_match( @@ -165,7 +176,7 @@ async def test_hubspot_erasure_request_task( temp_masking = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False # Allow delete - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite_name_and_email, graph, diff --git a/tests/ops/integration_tests/saas/test_iterable_task.py b/tests/ops/integration_tests/saas/test_iterable_task.py index e9a4bb8c22..5d64fed5ab 100644 --- a/tests/ops/integration_tests/saas/test_iterable_task.py +++ b/tests/ops/integration_tests/saas/test_iterable_task.py @@ -9,13 +9,21 @@ class TestIterableConnector: def test_connection(self, iterable_runner: ConnectorRunner): iterable_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, + dsr_version, + request, iterable_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, iterable_erasure_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( _, erasure_results, diff --git a/tests/ops/integration_tests/saas/test_iterate_task.py b/tests/ops/integration_tests/saas/test_iterate_task.py index 67d7564262..5de9db1c1b 100644 --- a/tests/ops/integration_tests/saas/test_iterate_task.py +++ b/tests/ops/integration_tests/saas/test_iterate_task.py @@ -9,9 +9,20 @@ class TestIterateConnector: def test_connection(self, iterate_runner: ConnectorRunner): iterate_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request( - self, iterate_runner: ConnectorRunner, policy, iterate_identity_email: str + self, + iterate_runner: ConnectorRunner, + policy, + iterate_identity_email: str, + dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + access_results = await iterate_runner.access_request( access_policy=policy, identities={"email": iterate_identity_email} ) @@ -21,13 +32,21 @@ async def test_access_request( ) @pytest.mark.skip(reason="Unable to create erasure data") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_strict_erasure_request( self, + dsr_version, + request, iterate_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, iterate_erasure_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( access_results, 
erasure_results, diff --git a/tests/ops/integration_tests/saas/test_jira_task.py b/tests/ops/integration_tests/saas/test_jira_task.py index b12b6c6414..d213e2f6ef 100644 --- a/tests/ops/integration_tests/saas/test_jira_task.py +++ b/tests/ops/integration_tests/saas/test_jira_task.py @@ -1,16 +1,14 @@ -import random from time import sleep import pytest import requests from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import get_config +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match CONFIG = get_config() @@ -25,19 +23,24 @@ def test_jira_connection_test(jira_connection_config) -> None: @pytest.mark.skip(reason="No active account") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_jira_access_request_task( db, + dsr_version, + request, policy, + privacy_request, jira_connection_config, jira_dataset_config, jira_identity_email, # jira_user_name, ) -> None: """Full access request based on the Jira SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_jira_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": jira_identity_email}) privacy_request.cache_identity(identity) @@ -45,7 +48,7 @@ async def test_jira_access_request_task( merged_graph = jira_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -76,9 +79,15 @@ async def test_jira_access_request_task( @pytest.mark.skip(reason="No active account") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_jira_erasure_request_task( db, - policy, + dsr_version, + request, + privacy_request, erasure_policy_string_rewrite, jira_connection_config, jira_dataset_config, @@ -86,13 +95,14 @@ async def test_jira_erasure_request_task( jira_create_erasure_data, ) -> None: """Full erasure request based on the Jira SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False # Allow Delete - privacy_request = PrivacyRequest( - id=f"test_jira_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": jira_erasure_identity_email}) privacy_request.cache_identity(identity) @@ -100,9 +110,9 @@ async def test_jira_erasure_request_task( merged_graph = jira_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [jira_connection_config], {"email": jira_erasure_identity_email}, @@ -124,7 +134,7 @@ async def test_jira_erasure_request_task( ], ) - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git 
a/tests/ops/integration_tests/saas/test_klaviyo_task.py b/tests/ops/integration_tests/saas/test_klaviyo_task.py index 472550dcdd..b513887bb7 100644 --- a/tests/ops/integration_tests/saas/test_klaviyo_task.py +++ b/tests/ops/integration_tests/saas/test_klaviyo_task.py @@ -9,12 +9,20 @@ class TestKlaviyoConnector: def test_connection(self, klaviyo_runner: ConnectorRunner): klaviyo_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request( self, + dsr_version, + request, klaviyo_runner: ConnectorRunner, policy: Policy, klaviyo_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + access_results = await klaviyo_runner.access_request( access_policy=policy, identities={"email": klaviyo_identity_email} ) @@ -25,14 +33,22 @@ async def test_access_request( == klaviyo_identity_email ) + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, + request, + dsr_version, klaviyo_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, klaviyo_erasure_identity_email: str, klaviyo_erasure_data, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( _, erasure_results, @@ -57,12 +73,20 @@ async def test_old_consent_request( ) assert consent_results == {"opt_in": True, "opt_out": True} + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_new_consent_request( self, + dsr_version, + request, klaviyo_runner: ConnectorRunner, consent_policy: Policy, klaviyo_erasure_identity_email, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + consent_results = await klaviyo_runner.new_consent_request( consent_policy, {"email": klaviyo_erasure_identity_email} ) diff --git a/tests/ops/integration_tests/saas/test_kustomer_task.py b/tests/ops/integration_tests/saas/test_kustomer_task.py index 0814487a63..c3c98c344b 100644 --- a/tests/ops/integration_tests/saas/test_kustomer_task.py +++ b/tests/ops/integration_tests/saas/test_kustomer_task.py @@ -1,16 +1,14 @@ -import random - import pytest import requests from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import get_config +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match +from tests.ops.test_helpers.cache_secrets_helper import clear_cache_identities CONFIG = get_config() @@ -22,18 +20,23 @@ def test_kustomer_connection_test(kustomer_connection_config) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_kustomer_access_request_task_with_email( db, policy, kustomer_connection_config, kustomer_dataset_config, kustomer_identity_email, + privacy_request, + request, + dsr_version, ) -> None: """Full access request based on the Kustomer SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_kustomer_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": 
kustomer_identity_email}) privacy_request.cache_identity(identity) @@ -41,7 +44,7 @@ async def test_kustomer_access_request_task_with_email( merged_graph = kustomer_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -65,18 +68,23 @@ async def test_kustomer_access_request_task_with_email( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_kustomer_access_request_task_with_non_existent_email( db, policy, kustomer_connection_config, kustomer_dataset_config, + privacy_request, + dsr_version, + request, kustomer_non_existent_identity_email, ) -> None: """Access request that returns a 404 but succeeds""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_kustomer_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": kustomer_non_existent_identity_email}) privacy_request.cache_identity(identity) @@ -84,7 +92,7 @@ async def test_kustomer_access_request_task_with_non_existent_email( merged_graph = kustomer_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -99,18 +107,24 @@ async def test_kustomer_access_request_task_with_non_existent_email( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_kustomer_access_request_task_with_phone_number( db, policy, kustomer_connection_config, kustomer_dataset_config, kustomer_identity_phone_number, + privacy_request, + dsr_version, + request, ) -> None: """Full access request based on the Kustomer SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + clear_cache_identities(privacy_request.id) - privacy_request = PrivacyRequest( - id=f"test_kustomer_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"phone_number": kustomer_identity_phone_number}) privacy_request.cache_identity(identity) @@ -118,7 +132,7 @@ async def test_kustomer_access_request_task_with_phone_number( merged_graph = kustomer_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -142,23 +156,30 @@ async def test_kustomer_access_request_task_with_phone_number( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_kustomer_erasure_request_task( db, - policy, erasure_policy_string_rewrite, kustomer_connection_config, kustomer_dataset_config, kustomer_erasure_identity_email, kustomer_create_erasure_data, + privacy_request, + dsr_version, + request, ) -> None: """Full erasure request based on the Kustomer SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False # Allow Delete - privacy_request = PrivacyRequest( - id=f"test_kustomer_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": kustomer_erasure_identity_email}) privacy_request.cache_identity(identity) @@ -166,9 
+187,9 @@ async def test_kustomer_erasure_request_task( merged_graph = kustomer_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [kustomer_connection_config], {"email": kustomer_erasure_identity_email}, @@ -181,7 +202,7 @@ async def test_kustomer_erasure_request_task( keys=["type", "id", "attributes", "relationships", "links"], ) - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, @@ -211,23 +232,30 @@ async def test_kustomer_erasure_request_task( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_kustomer_erasure_request_task_non_existent_email( db, - policy, + privacy_request, erasure_policy_string_rewrite, kustomer_connection_config, kustomer_dataset_config, kustomer_non_existent_identity_email, kustomer_create_erasure_data, + dsr_version, + request, ) -> None: """Full erasure request based on the Kustomer SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False # Allow Delete - privacy_request = PrivacyRequest( - id=f"test_kustomer_erasure_request_task_non_existent_email{random.randint(0, 1000)}" - ) + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) + identity = Identity(**{"email": kustomer_non_existent_identity_email}) privacy_request.cache_identity(identity) @@ -235,7 +263,16 @@ async def test_kustomer_erasure_request_task_non_existent_email( merged_graph = kustomer_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - x = await graph_task.run_erasure( + v = access_runner_tester( + privacy_request, + erasure_policy_string_rewrite, + graph, + [kustomer_connection_config], + {"email": kustomer_non_existent_identity_email}, + db, + ) + + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_mailchimp_task.py b/tests/ops/integration_tests/saas/test_mailchimp_task.py index f5bf884d94..57228a3fd5 100644 --- a/tests/ops/integration_tests/saas/test_mailchimp_task.py +++ b/tests/ops/integration_tests/saas/test_mailchimp_task.py @@ -1,14 +1,11 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import ExecutionLog, PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures -from tests.ops.graph.graph_test_util import assert_rows_match, records_matching_fields +from tests.conftest import access_runner_tester, erasure_runner_tester +from tests.ops.graph.graph_test_util import assert_rows_match @pytest.mark.integration_saas @@ -18,18 +15,23 @@ def test_mailchimp_connection_test(mailchimp_connection_config) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_mailchimp_access_request_task( db, policy, mailchimp_connection_config, mailchimp_dataset_config, mailchimp_identity_email, + privacy_request, + dsr_version, + request, ) -> None: """Full access request based on the Mailchimp SaaS config""" + 
request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_mailchimp_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": mailchimp_identity_email}) privacy_request.cache_identity(identity) @@ -37,7 +39,7 @@ async def test_mailchimp_access_request_task( merged_graph = mailchimp_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -95,20 +97,27 @@ async def test_mailchimp_access_request_task( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_mailchimp_erasure_request_task( db, - policy, + privacy_request, erasure_policy_string_rewrite, mailchimp_connection_config, mailchimp_dataset_config, mailchimp_identity_email, reset_mailchimp_data, + dsr_version, + request, ) -> None: """Full erasure request based on the Mailchimp SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) - privacy_request = PrivacyRequest( - id=f"test_mailchimp_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": mailchimp_identity_email}) privacy_request.cache_identity(identity) @@ -116,16 +125,16 @@ async def test_mailchimp_erasure_request_task( merged_graph = mailchimp_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - await graph_task.run_access_request( + access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [mailchimp_connection_config], {"email": mailchimp_identity_email}, db, ) - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_mailchimp_transactional_task.py b/tests/ops/integration_tests/saas/test_mailchimp_transactional_task.py index 3bc4405411..458d33d356 100644 --- a/tests/ops/integration_tests/saas/test_mailchimp_transactional_task.py +++ b/tests/ops/integration_tests/saas/test_mailchimp_transactional_task.py @@ -1,6 +1,5 @@ import json from unittest import mock -from uuid import uuid4 import pytest @@ -8,8 +7,7 @@ from fides.api.models.privacy_request import ( ExecutionLog, ExecutionLogStatus, - PrivacyRequest, - PrivacyRequestStatus, + RequestTask, ) from fides.api.schemas.redis_cache import Identity from fides.api.schemas.saas.saas_config import SaaSRequest @@ -18,7 +16,7 @@ from fides.api.service.privacy_request.request_runner_service import ( build_consent_dataset_graph, ) -from fides.api.task import graph_task +from tests.conftest import consent_runner_tester @pytest.mark.integration_saas @@ -31,26 +29,34 @@ def test_mailchimp_transactional_connection_test( @pytest.mark.integration_saas @pytest.mark.asyncio @pytest.mark.usefixtures("reset_mailchimp_transactional_data") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_mailchimp_transactional_consent_request_task_old_workflow( db, consent_policy, mailchimp_transactional_connection_config, mailchimp_transactional_dataset_config, mailchimp_transactional_identity_email, + privacy_request, + dsr_version, + request, ) -> None: """Full consent request based on the Mailchimp Transactional (Mandrill) SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 
2.0 - privacy_request = PrivacyRequest( - id=str(uuid4()), - consent_preferences=[{"data_use": "marketing.advertising", "opt_in": False}], - ) + privacy_request.consent_preferences = [ + {"data_use": "marketing.advertising", "opt_in": False} + ] + privacy_request.save(db) identity = Identity(**{"email": mailchimp_transactional_identity_email}) privacy_request.cache_identity(identity) dataset_name = "mailchimp_transactional_instance" - v = await graph_task.run_consent_request( + v = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([mailchimp_transactional_dataset_config]), @@ -113,6 +119,10 @@ async def test_mailchimp_transactional_consent_request_task_old_workflow( @pytest.mark.integration_saas @pytest.mark.asyncio @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_mailchimp_transactional_consent_prepared_requests_old_workflow( mocked_client_send, db, @@ -120,18 +130,22 @@ async def test_mailchimp_transactional_consent_prepared_requests_old_workflow( mailchimp_transactional_connection_config, mailchimp_transactional_dataset_config, mailchimp_transactional_identity_email, + privacy_request, + dsr_version, + request, ) -> None: """Assert attributes of the PreparedRequest created by the client for running the consent request""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=str(uuid4()), - consent_preferences=[{"data_use": "marketing.advertising", "opt_in": False}], - ) + privacy_request.consent_preferences = [ + {"data_use": "marketing.advertising", "opt_in": False} + ] + privacy_request.save(db) identity = Identity(**{"email": mailchimp_transactional_identity_email}) privacy_request.cache_identity(identity) - await graph_task.run_consent_request( + consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([mailchimp_transactional_dataset_config]), @@ -152,6 +166,10 @@ async def test_mailchimp_transactional_consent_prepared_requests_old_workflow( @pytest.mark.integration_saas @pytest.mark.asyncio @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_no_prepared_request_fired_without_consent_preferences_old_workflow( mocked_client_send, db, @@ -159,17 +177,17 @@ async def test_no_prepared_request_fired_without_consent_preferences_old_workflo mailchimp_transactional_connection_config, mailchimp_transactional_dataset_config, mailchimp_transactional_identity_email, + privacy_request, + dsr_version, + request, ) -> None: """Assert attributes of the PreparedRequest created by the client for running the consent request""" - - privacy_request = PrivacyRequest( - id=str(uuid4()), - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 identity = Identity(**{"email": mailchimp_transactional_identity_email}) privacy_request.cache_identity(identity) - await graph_task.run_consent_request( + consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([mailchimp_transactional_dataset_config]), @@ -184,6 +202,10 @@ async def test_no_prepared_request_fired_without_consent_preferences_old_workflo @pytest.mark.integration_saas @pytest.mark.asyncio @pytest.mark.usefixtures("reset_mailchimp_transactional_data") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", 
"use_dsr_2_0"], +) async def test_mailchimp_transactional_consent_request_task_new_workflow( db, consent_policy, @@ -192,20 +214,21 @@ async def test_mailchimp_transactional_consent_request_task_new_workflow( mailchimp_transactional_identity_email, privacy_preference_history, privacy_preference_history_us_ca_provide, + privacy_request, system, + dsr_version, + request, ) -> None: """Full consent request based on the Mailchimp Transactional (Mandrill) SaaS config with new workflow where preferences are saved w.r.t privacy notices Assert that only relevant preferences get "complete" log, others get "skipped" """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + mailchimp_transactional_connection_config.system_id = system.id mailchimp_transactional_connection_config.save(db) - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) - privacy_request.save(db) # This preference is relevant on data use privacy_preference_history.privacy_request_id = privacy_request.id privacy_preference_history.save(db=db) @@ -219,7 +242,7 @@ async def test_mailchimp_transactional_consent_request_task_new_workflow( dataset_name = "mailchimp_transactional_instance" - v = await graph_task.run_consent_request( + v = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([mailchimp_transactional_dataset_config]), @@ -297,6 +320,10 @@ async def test_mailchimp_transactional_consent_request_task_new_workflow( @pytest.mark.integration_saas @pytest.mark.asyncio @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_mailchimp_transactional_consent_prepared_requests_new_workflow( mocked_client_send, db, @@ -306,8 +333,11 @@ async def test_mailchimp_transactional_consent_prepared_requests_new_workflow( mailchimp_transactional_identity_email, privacy_preference_history, privacy_request_with_consent_policy, + dsr_version, + request, ) -> None: """Assert attributes of the PreparedRequest created by the client for running the consent request""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 privacy_preference_history.privacy_request_id = ( privacy_request_with_consent_policy.id @@ -317,7 +347,7 @@ async def test_mailchimp_transactional_consent_prepared_requests_new_workflow( identity = Identity(**{"email": mailchimp_transactional_identity_email}) privacy_request_with_consent_policy.cache_identity(identity) - await graph_task.run_consent_request( + consent_runner_tester( privacy_request_with_consent_policy, consent_policy, build_consent_dataset_graph([mailchimp_transactional_dataset_config]), @@ -345,6 +375,10 @@ async def test_mailchimp_transactional_consent_prepared_requests_new_workflow( @pytest.mark.integration_saas @pytest.mark.asyncio @pytest.mark.usefixtures("reset_mailchimp_transactional_data") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_mailchimp_transactional_consent_request_task_new_workflow_skipped( db, consent_policy, @@ -353,15 +387,16 @@ async def test_mailchimp_transactional_consent_request_task_new_workflow_skipped mailchimp_transactional_identity_email, system, privacy_preference_history_us_ca_provide, + privacy_request, + dsr_version, + request, ) -> None: """Test privacy notice/data use system mismatch causes the request to be skipped""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 
and 2.0 + mailchimp_transactional_connection_config.system_id = system.id mailchimp_transactional_connection_config.save(db) - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) - privacy_request.save(db) privacy_preference_history_us_ca_provide.privacy_request_id = privacy_request.id privacy_preference_history_us_ca_provide.save(db=db) @@ -370,7 +405,7 @@ async def test_mailchimp_transactional_consent_request_task_new_workflow_skipped dataset_name = "mailchimp_transactional_instance" - v = await graph_task.run_consent_request( + v = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([mailchimp_transactional_dataset_config]), @@ -378,7 +413,6 @@ async def test_mailchimp_transactional_consent_request_task_new_workflow_skipped {"email": mailchimp_transactional_identity_email}, db, ) - assert v == { f"{dataset_name}:{dataset_name}": False }, "graph has one node, and request skipped" @@ -416,6 +450,10 @@ async def test_mailchimp_transactional_consent_request_task_new_workflow_skipped @pytest.mark.asyncio @pytest.mark.usefixtures("reset_mailchimp_transactional_data") @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_mailchimp_transactional_consent_request_task_error( mocked_client_send, db, @@ -426,18 +464,19 @@ async def test_mailchimp_transactional_consent_request_task_error( system, privacy_preference_history, privacy_preference_history_us_ca_provide, + privacy_request, + dsr_version, + request, ) -> None: """Assert logging is correctly created for privacy preferences on errored request Assert case when some privacy preferences were relevant but not all. 
""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + mocked_client_send.side_effect = Exception("KeyError") mailchimp_transactional_connection_config.system_id = system.id mailchimp_transactional_connection_config.save(db) - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) - privacy_request.save(db) # This preference matches on data use privacy_preference_history.privacy_request_id = privacy_request.id privacy_preference_history.save(db=db) @@ -451,8 +490,18 @@ async def test_mailchimp_transactional_consent_request_task_error( dataset_name = "mailchimp_transactional_instance" - with pytest.raises(Exception): - await graph_task.run_consent_request( + if dsr_version == "use_dsr_2_0": + with pytest.raises(Exception): + consent_runner_tester( + privacy_request, + consent_policy, + build_consent_dataset_graph([mailchimp_transactional_dataset_config]), + [mailchimp_transactional_connection_config], + {"email": mailchimp_transactional_identity_email}, + db, + ) + else: + consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([mailchimp_transactional_dataset_config]), @@ -460,6 +509,16 @@ async def test_mailchimp_transactional_consent_request_task_error( {"email": mailchimp_transactional_identity_email}, db, ) + rt = privacy_request.consent_tasks.filter( + RequestTask.collection_address + == "mailchimp_transactional_instance:mailchimp_transactional_instance" + ).first() + assert rt.status == ExecutionLogStatus.error # Matches status of Execution Log + terminator_task = privacy_request.get_terminate_task_by_action( + ActionType.consent + ) + # Terminator task was also marked "errored" + assert terminator_task.status == ExecutionLogStatus.error execution_logs = db.query(ExecutionLog).filter_by( privacy_request_id=privacy_request.id diff --git a/tests/ops/integration_tests/saas/test_onesignal_task.py b/tests/ops/integration_tests/saas/test_onesignal_task.py index 4da38ba786..42aed62d66 100644 --- a/tests/ops/integration_tests/saas/test_onesignal_task.py +++ b/tests/ops/integration_tests/saas/test_onesignal_task.py @@ -10,13 +10,25 @@ class TestOneSignalConnector: def test_connection(self, onesignal_runner: ConnectorRunner): onesignal_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request( - self, onesignal_runner: ConnectorRunner, policy, onesignal_identity_email: str + self, + onesignal_runner: ConnectorRunner, + policy, + dsr_version, + request, + onesignal_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + access_results = await onesignal_runner.access_request( access_policy=policy, identities={"email": onesignal_identity_email} ) + @pytest.mark.usefixtures("use_dsr_3_0") async def test_non_strict_erasure_request( self, onesignal_runner: ConnectorRunner, @@ -26,6 +38,11 @@ async def test_non_strict_erasure_request( onesignal_erasure_data, onesignal_client, ): + """ + Testing this just on one scheduler: DSR 3.0 due to issues with data fixtures not settling down in time, + {'success': False, 'errors': ['A player with that identifier (XXXXX) already exists. Usually this is caused by calling the same update for multiple player ids. 
Please double check your update calls and try again']} + """ + player_id = onesignal_erasure_data ( access_results, diff --git a/tests/ops/integration_tests/saas/test_oracle_responsys_task.py b/tests/ops/integration_tests/saas/test_oracle_responsys_task.py index 52758ba8ac..b1db9b6484 100644 --- a/tests/ops/integration_tests/saas/test_oracle_responsys_task.py +++ b/tests/ops/integration_tests/saas/test_oracle_responsys_task.py @@ -10,12 +10,20 @@ class TestOracleResponsysConnector: def test_connection(self, oracle_responsys_runner: ConnectorRunner): oracle_responsys_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request_by_email( self, + dsr_version, + request, oracle_responsys_runner: ConnectorRunner, policy, oracle_responsys_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + access_results = await oracle_responsys_runner.access_request( access_policy=policy, identities={"email": oracle_responsys_identity_email} ) @@ -26,12 +34,20 @@ async def test_access_request_by_email( == oracle_responsys_identity_email ) + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request_by_phone_number( self, + dsr_version, + request, oracle_responsys_runner: ConnectorRunner, policy, oracle_responsys_identity_phone_number: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + access_results = await oracle_responsys_runner.access_request( access_policy=policy, identities={"phone_number": oracle_responsys_identity_phone_number}, @@ -43,14 +59,22 @@ async def test_access_request_by_phone_number( == oracle_responsys_identity_phone_number[1:] ) + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request_by_email( self, + dsr_version, + request, oracle_responsys_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, oracle_responsys_erasure_identity_email: str, oracle_responsys_erasure_data, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( access_results, erasure_results, @@ -64,14 +88,22 @@ async def test_non_strict_erasure_request_by_email( "oracle_responsys_instance:profile_list": 0, } + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request_by_phone_number( self, + dsr_version, + request, oracle_responsys_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, oracle_responsys_erasure_identity_phone_number: str, oracle_responsys_erasure_data, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( access_results, erasure_results, diff --git a/tests/ops/integration_tests/saas/test_outreach_task.py b/tests/ops/integration_tests/saas/test_outreach_task.py index 82695c4759..af87359082 100644 --- a/tests/ops/integration_tests/saas/test_outreach_task.py +++ b/tests/ops/integration_tests/saas/test_outreach_task.py @@ -1,15 +1,12 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.filter_results import filter_data_categories from fides.api.task.graph_task import get_cached_data_for_erasures from 
fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match @@ -22,18 +19,23 @@ def test_outreach_connection_test(outreach_connection_config) -> None: @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_outreach_access_request_task( db, + dsr_version, + request, policy, + privacy_request, outreach_connection_config, outreach_dataset_config, outreach_identity_email, ) -> None: """Full access request based on the Outreach SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_outreach_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": outreach_identity_email}) privacy_request.cache_identity(identity) @@ -41,7 +43,7 @@ async def test_outreach_access_request_task( merged_graph = outreach_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -95,9 +97,15 @@ async def test_outreach_access_request_task( @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_outreach_erasure_request_task( db, - policy, + dsr_version, + request, + privacy_request, erasure_policy_string_rewrite, outreach_connection_config, outreach_dataset_config, @@ -105,13 +113,14 @@ async def test_outreach_erasure_request_task( outreach_create_erasure_data, ) -> None: """Full erasure request based on the Outreach SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False # Allow Delete - privacy_request = PrivacyRequest( - id=f"test_outreach_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": outreach_erasure_identity_email}) privacy_request.cache_identity(identity) @@ -119,9 +128,9 @@ async def test_outreach_erasure_request_task( merged_graph = outreach_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [outreach_connection_config], {"email": outreach_erasure_identity_email}, @@ -140,7 +149,7 @@ async def test_outreach_erasure_request_task( keys=["type", "id", "attributes", "links"], ) - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_qualtrics_task.py b/tests/ops/integration_tests/saas/test_qualtrics_task.py index 80f1b86b5a..1ccf75d98f 100644 --- a/tests/ops/integration_tests/saas/test_qualtrics_task.py +++ b/tests/ops/integration_tests/saas/test_qualtrics_task.py @@ -9,13 +9,21 @@ class TestQualtricsConnector: def test_connection(self, qualtrics_runner: ConnectorRunner): qualtrics_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, + dsr_version, + request, 
qualtrics_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, qualtrics_erasure_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( access_results, erasure_results, diff --git a/tests/ops/integration_tests/saas/test_recharge_tasks.py b/tests/ops/integration_tests/saas/test_recharge_tasks.py index 1405c5c334..a403950593 100644 --- a/tests/ops/integration_tests/saas/test_recharge_tasks.py +++ b/tests/ops/integration_tests/saas/test_recharge_tasks.py @@ -1,14 +1,11 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match @@ -19,18 +16,23 @@ def test_recharge_connection_test(recharge_connection_config) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_recharge_access_request_task( db, policy, recharge_connection_config, recharge_dataset_config, recharge_identity_email, + privacy_request, + dsr_version, + request, ) -> None: """Full access request based on the Recharge SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_recharge_access_request_task_{random.randint(0, 1000)}" - ) identity_attribute = "email" identity_value = recharge_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -40,7 +42,7 @@ async def test_recharge_access_request_task( dataset_name = recharge_connection_config.get_saas_config().fides_key merged_graph = recharge_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -123,19 +125,27 @@ async def test_recharge_access_request_task( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_recharge_erasure_request_task( db, - policy, + privacy_request, erasure_policy_complete_mask, recharge_connection_config, recharge_dataset_config, recharge_erasure_identity_email, recharge_erasure_data, recharge_test_client, + dsr_version, + request, ) -> None: - privacy_request = PrivacyRequest( - id=f"test_recharge_erasure_request_task_{random.randint(0, 1000)}" - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_complete_mask.id + privacy_request.save(db) + identity_attribute = "email" identity_value = recharge_erasure_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -146,9 +156,9 @@ async def test_recharge_erasure_request_task( merged_graph = recharge_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_complete_mask, graph, [recharge_connection_config], {"email": recharge_erasure_identity_email}, @@ -228,7 +238,7 @@ async def test_recharge_erasure_request_task( temp_masking = CONFIG.execution.masking_strict 
CONFIG.execution.masking_strict = False - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_complete_mask, graph, diff --git a/tests/ops/integration_tests/saas/test_recurly_task.py b/tests/ops/integration_tests/saas/test_recurly_task.py index 7710629441..78b376c42f 100644 --- a/tests/ops/integration_tests/saas/test_recurly_task.py +++ b/tests/ops/integration_tests/saas/test_recurly_task.py @@ -8,9 +8,20 @@ class TestRecurlyConnector: def test_connection(self, recurly_runner: ConnectorRunner): recurly_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request( - self, recurly_runner: ConnectorRunner, policy, recurly_identity_email: str + self, + recurly_runner: ConnectorRunner, + policy, + dsr_version, + request, + recurly_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + access_results = await recurly_runner.access_request( access_policy=policy, identities={"email": recurly_identity_email} ) diff --git a/tests/ops/integration_tests/saas/test_rollbar_task.py b/tests/ops/integration_tests/saas/test_rollbar_task.py index 68b0144812..1af54e904f 100644 --- a/tests/ops/integration_tests/saas/test_rollbar_task.py +++ b/tests/ops/integration_tests/saas/test_rollbar_task.py @@ -1,14 +1,11 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match @@ -21,18 +18,23 @@ def test_rollbar_connection_test(rollbar_connection_config) -> None: @pytest.mark.skip(reason="Pending account resolution") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_rollbar_access_request_task( db, policy, rollbar_connection_config, rollbar_dataset_config, rollbar_identity_email, + dsr_version, + request, + privacy_request, ) -> None: """Full access request based on the Rollbar SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_rollbar_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": rollbar_identity_email}) privacy_request.cache_identity(identity) @@ -40,7 +42,7 @@ async def test_rollbar_access_request_task( merged_graph = rollbar_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -90,6 +92,10 @@ async def test_rollbar_access_request_task( @pytest.mark.skip(reason="Pending account resolution") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_rollbar_erasure_request_task( db, policy, @@ -99,12 +105,13 @@ async def test_rollbar_erasure_request_task( rollbar_erasure_identity_email, rollbar_erasure_data, rollbar_test_client, + dsr_version, + request, + privacy_request, ) -> None: """Full erasure request based on the Rollbar SaaS config""" + 
request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_rollbar_erasure_request_task_{random.randint(0, 1000)}" - ) identity_kwargs = {"email": rollbar_erasure_identity_email} identity = Identity(**identity_kwargs) @@ -113,7 +120,7 @@ async def test_rollbar_erasure_request_task( dataset_name = rollbar_connection_config.get_saas_config().fides_key merged_graph = rollbar_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -157,7 +164,7 @@ async def test_rollbar_erasure_request_task( temp_masking = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_salesforce_task.py b/tests/ops/integration_tests/saas/test_salesforce_task.py index ef8c17d4c8..84acb47c9f 100644 --- a/tests/ops/integration_tests/saas/test_salesforce_task.py +++ b/tests/ops/integration_tests/saas/test_salesforce_task.py @@ -1,15 +1,12 @@ -import random - import pytest import requests from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match @@ -22,7 +19,14 @@ def test_salesforce_connection_test(salesforce_connection_config) -> None: @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_salesforce_access_request_task_by_email( + privacy_request, + dsr_version, + request, policy, salesforce_identity_email, salesforce_connection_config, @@ -30,10 +34,8 @@ async def test_salesforce_access_request_task_by_email( db, ) -> None: """Full access request based on the Salesforce SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_salesforce_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": salesforce_identity_email}) privacy_request.cache_identity(identity) @@ -41,7 +43,7 @@ async def test_salesforce_access_request_task_by_email( merged_graph = salesforce_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -370,8 +372,15 @@ async def test_salesforce_access_request_task_by_email( @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_salesforce_access_request_task_by_phone_number( policy, + dsr_version, + request, + privacy_request, salesforce_identity_phone_number, salesforce_identity_email, salesforce_connection_config, @@ -379,10 +388,8 @@ async def test_salesforce_access_request_task_by_phone_number( db, ) -> None: """Full access request based on the Salesforce SaaS config""" + 
request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_salesforce_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"phone_number": salesforce_identity_phone_number}) privacy_request.cache_identity(identity) @@ -390,7 +397,7 @@ async def test_salesforce_access_request_task_by_phone_number( merged_graph = salesforce_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -716,9 +723,15 @@ async def test_salesforce_access_request_task_by_phone_number( @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_salesforce_erasure_request_task( db, - policy, + dsr_version, + request, + privacy_request, erasure_policy_string_rewrite, salesforce_connection_config, salesforce_dataset_config, @@ -726,6 +739,9 @@ async def test_salesforce_erasure_request_task( salesforce_create_erasure_data, ) -> None: """Full erasure request based on the Salesforce SaaS config""" + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 ( account_id, @@ -735,9 +751,6 @@ async def test_salesforce_erasure_request_task( campaign_member_id, ) = salesforce_create_erasure_data - privacy_request = PrivacyRequest( - id=f"test_salesforce_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": salesforce_erasure_identity_email}) privacy_request.cache_identity(identity) @@ -745,9 +758,9 @@ async def test_salesforce_erasure_request_task( merged_graph = salesforce_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [salesforce_connection_config], {"email": salesforce_erasure_identity_email}, @@ -1058,7 +1071,7 @@ async def test_salesforce_erasure_request_task( masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = True - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_segment_task.py b/tests/ops/integration_tests/saas/test_segment_task.py index ba7b9e97fb..64ff27434e 100644 --- a/tests/ops/integration_tests/saas/test_segment_task.py +++ b/tests/ops/integration_tests/saas/test_segment_task.py @@ -1,15 +1,12 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.filter_results import filter_data_categories from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match @@ -22,18 +19,23 @@ def test_segment_connection_test(segment_connection_config) -> None: @pytest.mark.skip(reason="Pending account resolution") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", 
"use_dsr_2_0"], +) async def test_segment_access_request_task( db, + dsr_version, + request, policy, + privacy_request, segment_connection_config, segment_dataset_config, segment_identity_email, ) -> None: """Full access request based on the Segment SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_segment_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": segment_identity_email}) privacy_request.cache_identity(identity) @@ -41,7 +43,7 @@ async def test_segment_access_request_task( merged_graph = segment_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -147,34 +149,40 @@ async def test_segment_access_request_task( @pytest.mark.skip(reason="Pending account resolution") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_segment_erasure_request_task( db, - policy, + dsr_version, + request, + erasure_policy, + privacy_request_with_erasure_policy, segment_connection_config, segment_dataset_config, segment_erasure_identity_email, segment_erasure_data, ) -> None: """Full erasure request based on the Segment SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False # Allow GDPR Delete # Create user for GDPR delete erasure_email = segment_erasure_identity_email - privacy_request = PrivacyRequest( - id=f"test_segment_access_request_task_{random.randint(0, 1000)}" - ) + identity = Identity(**{"email": erasure_email}) - privacy_request.cache_identity(identity) + privacy_request_with_erasure_policy.cache_identity(identity) dataset_name = segment_connection_config.get_saas_config().fides_key merged_graph = segment_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( - privacy_request, - policy, + v = access_runner_tester( + privacy_request_with_erasure_policy, + erasure_policy, graph, [segment_connection_config], {"email": erasure_email}, @@ -218,13 +226,13 @@ async def test_segment_erasure_request_task( ], ) - x = await graph_task.run_erasure( - privacy_request, - policy, + x = erasure_runner_tester( + privacy_request_with_erasure_policy, + erasure_policy, graph, [segment_connection_config], {"email": erasure_email}, - get_cached_data_for_erasures(privacy_request.id), + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) diff --git a/tests/ops/integration_tests/saas/test_sendgrid_task.py b/tests/ops/integration_tests/saas/test_sendgrid_task.py index 2908a6a951..0b98b5c892 100644 --- a/tests/ops/integration_tests/saas/test_sendgrid_task.py +++ b/tests/ops/integration_tests/saas/test_sendgrid_task.py @@ -1,14 +1,11 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.fixtures.saas.sendgrid_fixtures import contact_exists from tests.ops.graph.graph_test_util import 
assert_rows_match from tests.ops.test_helpers.saas_test_utils import poll_for_existence @@ -21,18 +18,23 @@ def test_sendgrid_connection_test(sendgrid_connection_config) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_sendgrid_access_request_task( db, policy, + privacy_request, sendgrid_connection_config, sendgrid_dataset_config, sendgrid_identity_email, + request, + dsr_version, ) -> None: """Full access request based on the Sendgrid SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_sendgrid_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": sendgrid_identity_email}) privacy_request.cache_identity(identity) @@ -40,7 +42,7 @@ async def test_sendgrid_access_request_task( merged_graph = sendgrid_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -76,9 +78,15 @@ async def test_sendgrid_access_request_task( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_sendgrid_erasure_request_task( db, - policy, + privacy_request, + dsr_version, + request, erasure_policy_string_rewrite, sendgrid_secrets, sendgrid_connection_config, @@ -87,10 +95,11 @@ async def test_sendgrid_erasure_request_task( sendgrid_erasure_data, ) -> None: """Full erasure request based on the Sendgrid SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) - privacy_request = PrivacyRequest( - id=f"test_sendgrid_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": sendgrid_erasure_identity_email}) privacy_request.cache_identity(identity) @@ -99,9 +108,9 @@ async def test_sendgrid_erasure_request_task( graph = DatasetGraph(merged_graph) # access our erasure identity - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [sendgrid_connection_config], {"email": sendgrid_erasure_identity_email}, @@ -134,7 +143,7 @@ async def test_sendgrid_erasure_request_task( ) temp_masking = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False # Allow delete - erasure = await graph_task.run_erasure( + erasure = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_sentry_task.py b/tests/ops/integration_tests/saas/test_sentry_task.py index 200ca6f7ab..14346bfe58 100644 --- a/tests/ops/integration_tests/saas/test_sentry_task.py +++ b/tests/ops/integration_tests/saas/test_sentry_task.py @@ -1,17 +1,14 @@ -import random -import time from typing import Any, Dict, List, Optional import pytest import requests from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.filter_results import filter_data_categories from fides.api.task.graph_task import get_cached_data_for_erasures +from tests.conftest import access_runner_tester, erasure_runner_tester from 
tests.ops.graph.graph_test_util import assert_rows_match from tests.ops.test_helpers.saas_test_utils import poll_for_existence @@ -25,18 +22,23 @@ def test_sentry_connection_test(sentry_connection_config) -> None: @pytest.mark.skip(reason="Pending account resolution") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_sentry_access_request_task( db, + dsr_version, + request, policy, + privacy_request, sentry_connection_config, sentry_dataset_config, sentry_identity_email, ) -> None: """Full access request based on the Sentry SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_sentry_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": sentry_identity_email}) privacy_request.cache_identity(identity) @@ -44,7 +46,7 @@ async def test_sentry_access_request_task( merged_graph = sentry_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -272,31 +274,39 @@ def sentry_erasure_test_prep(sentry_connection_config, db): @pytest.mark.skip(reason="Pending account resolution") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_sentry_erasure_request_task( - db, policy, sentry_connection_config, sentry_dataset_config + db, + dsr_version, + request, + erasure_policy, + privacy_request_with_erasure_policy, + sentry_connection_config, + sentry_dataset_config, ) -> None: """ Full erasure request based on the Sentry SaaS config. Also verifies issue data in access request. 
""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 erasure_email, issue_url, headers = sentry_erasure_test_prep( sentry_connection_config, db ) - privacy_request = PrivacyRequest( - id=f"test_sentry_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": erasure_email}) - privacy_request.cache_identity(identity) + privacy_request_with_erasure_policy.cache_identity(identity) dataset_name = sentry_connection_config.get_saas_config().fides_key merged_graph = sentry_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( - privacy_request, - policy, + v = access_runner_tester( + privacy_request_with_erasure_policy, + erasure_policy, graph, [sentry_connection_config], {"email": erasure_email}, @@ -375,13 +385,13 @@ async def test_sentry_erasure_request_task( assert v[f"{dataset_name}:issues"][0]["assignedTo"]["email"] == erasure_email - x = await graph_task.run_erasure( - privacy_request, - policy, + x = erasure_runner_tester( + privacy_request_with_erasure_policy, + erasure_policy, graph, [sentry_connection_config], {"email": erasure_email}, - get_cached_data_for_erasures(privacy_request.id), + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) diff --git a/tests/ops/integration_tests/saas/test_shippo_task.py b/tests/ops/integration_tests/saas/test_shippo_task.py index 62524b5a5c..a4f1c3756c 100644 --- a/tests/ops/integration_tests/saas/test_shippo_task.py +++ b/tests/ops/integration_tests/saas/test_shippo_task.py @@ -1,13 +1,10 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.config import get_config +from tests.conftest import access_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match CONFIG = get_config() @@ -20,18 +17,23 @@ def test_shippo_connection_test(shippo_connection_config) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_shippo_access_request_task( db, policy, + privacy_request, + dsr_version, + request, shippo_connection_config, shippo_dataset_config, shippo_identity_email, ) -> None: """Full access request based on the Shippo SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_shippo_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": shippo_identity_email}) privacy_request.cache_identity(identity) @@ -39,7 +41,7 @@ async def test_shippo_access_request_task( merged_graph = shippo_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, diff --git a/tests/ops/integration_tests/saas/test_shopify_task.py b/tests/ops/integration_tests/saas/test_shopify_task.py index f69237aaf4..dccb318830 100644 --- a/tests/ops/integration_tests/saas/test_shopify_task.py +++ b/tests/ops/integration_tests/saas/test_shopify_task.py @@ -1,16 +1,14 @@ -import random from time import sleep import pytest import requests from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity 
from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match @@ -21,18 +19,24 @@ def test_shopify_connection_test(shopify_connection_config) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_shopify_access_request_task( db, policy, + dsr_version, + request, + privacy_request, shopify_connection_config, shopify_dataset_config, shopify_identity_email, # shopify_access_data, ) -> None: """Full access request based on the Shopify SaaS config""" - privacy_request = PrivacyRequest( - id=f"test_shopify_access_request_task_{random.randint(0, 1000)}" - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + identity = Identity(**{"email": shopify_identity_email}) privacy_request.cache_identity(identity) @@ -40,7 +44,7 @@ async def test_shopify_access_request_task( merged_graph = shopify_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -263,20 +267,27 @@ async def test_shopify_access_request_task( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_shopify_erasure_request_task( db, - policy, + privacy_request, erasure_policy_string_rewrite, shopify_connection_config, shopify_dataset_config, shopify_erasure_identity_email, + dsr_version, + request, shopify_erasure_data, ) -> None: """Full erasure request based on the Shopify SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) - privacy_request = PrivacyRequest( - id=f"test_shopify_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": shopify_erasure_identity_email}) privacy_request.cache_identity(identity) @@ -284,9 +295,9 @@ async def test_shopify_erasure_request_task( merged_graph = shopify_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [shopify_connection_config], {"email": shopify_erasure_identity_email}, @@ -478,7 +489,7 @@ async def test_shopify_erasure_request_task( temp_masking = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = True - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_simon_data_task.py b/tests/ops/integration_tests/saas/test_simon_data_task.py index 01124dd689..63b90bd6db 100644 --- a/tests/ops/integration_tests/saas/test_simon_data_task.py +++ b/tests/ops/integration_tests/saas/test_simon_data_task.py @@ -9,14 +9,22 @@ class TestSimonDataConnector: def test_connection(self, simon_data_runner: ConnectorRunner): simon_data_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, + dsr_version, + request, simon_data_runner: ConnectorRunner, policy: Policy, 
erasure_policy_string_rewrite: Policy, simon_data_erasure_identity_email: str, simon_data_erasure_data, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( access_results, erasure_results, diff --git a/tests/ops/integration_tests/saas/test_slack_enterprise_task.py b/tests/ops/integration_tests/saas/test_slack_enterprise_task.py index 463ddc6274..3a6312af84 100644 --- a/tests/ops/integration_tests/saas/test_slack_enterprise_task.py +++ b/tests/ops/integration_tests/saas/test_slack_enterprise_task.py @@ -1,12 +1,9 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task +from tests.conftest import access_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match @@ -17,18 +14,23 @@ def test_slack_enterprise_connection_test(slack_enterprise_connection_config) -> @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_slack_enterprise_access_request_task( db, policy, + privacy_request, + dsr_version, + request, slack_enterprise_connection_config, slack_enterprise_dataset_config, slack_enterprise_identity_email, ) -> None: """Full access request based on the Slack Enterprise SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_slack_enterprise_access_request_task_{random.randint(0, 1000)}" - ) identity_attribute = "email" identity_value = slack_enterprise_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -38,7 +40,7 @@ async def test_slack_enterprise_access_request_task( merged_graph = slack_enterprise_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, diff --git a/tests/ops/integration_tests/saas/test_sparkpost_task.py b/tests/ops/integration_tests/saas/test_sparkpost_task.py index dd1e09c3d1..e224ca8444 100644 --- a/tests/ops/integration_tests/saas/test_sparkpost_task.py +++ b/tests/ops/integration_tests/saas/test_sparkpost_task.py @@ -9,9 +9,20 @@ class TestSparkPostConnector: def test_connection(self, sparkpost_runner: ConnectorRunner): sparkpost_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request( - self, sparkpost_runner: ConnectorRunner, policy, sparkpost_identity_email: str + self, + sparkpost_runner: ConnectorRunner, + policy, + dsr_version, + request, + sparkpost_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + access_results = await sparkpost_runner.access_request( access_policy=policy, identities={"email": sparkpost_identity_email} ) @@ -20,14 +31,22 @@ async def test_access_request( == sparkpost_identity_email ) + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, + dsr_version, + request, sparkpost_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, sparkpost_erasure_identity_email: str, sparkpost_erasure_data, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( access_results, erasure_results, diff --git 
a/tests/ops/integration_tests/saas/test_square_task.py b/tests/ops/integration_tests/saas/test_square_task.py index a9e73a9b6c..196f24c3ea 100644 --- a/tests/ops/integration_tests/saas/test_square_task.py +++ b/tests/ops/integration_tests/saas/test_square_task.py @@ -1,16 +1,15 @@ -import random from time import sleep import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match +from tests.ops.test_helpers.cache_secrets_helper import clear_cache_identities @pytest.mark.integration_saas @@ -20,18 +19,23 @@ def test_square_connection_test(square_connection_config) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_square_access_request_task_by_email( db, policy, + privacy_request, + dsr_version, + request, square_connection_config, square_dataset_config, square_identity_email, ) -> None: """Full access request based on the Square SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_square_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": square_identity_email}) privacy_request.cache_identity(identity) @@ -39,7 +43,7 @@ async def test_square_access_request_task_by_email( merged_graph = square_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -103,19 +107,27 @@ async def test_square_access_request_task_by_email( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_square_access_request_task_by_phone_number( db, policy, + dsr_version, + request, + privacy_request, square_connection_config, square_dataset_config, square_identity_email, square_identity_phone_number, ) -> None: """Full access request based on the Square SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + # Privacy request fixture already caches an email and a phone number, so + # clearing those first + clear_cache_identities(privacy_request.id) - privacy_request = PrivacyRequest( - id=f"test_square_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"phone_number": square_identity_phone_number}) privacy_request.cache_identity(identity) @@ -123,7 +135,7 @@ async def test_square_access_request_task_by_phone_number( merged_graph = square_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -161,19 +173,24 @@ async def test_square_access_request_task_by_phone_number( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_square_access_request_task_with_multiple_identities( db, policy, + dsr_version, + request, + privacy_request, square_connection_config, square_dataset_config, square_identity_email,
square_identity_phone_number, ) -> None: """Full access request based on the Square SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_square_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity( **{"email": square_identity_email, "phone_number": square_identity_phone_number} ) @@ -183,7 +200,7 @@ async def test_square_access_request_task_with_multiple_identities( merged_graph = square_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -226,9 +243,15 @@ async def test_square_access_request_task_with_multiple_identities( @pytest.mark.integration_saas @pytest.mark.integration_square @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_square_erasure_request_task( db, - policy, + privacy_request, + dsr_version, + request, erasure_policy_string_rewrite, square_connection_config, square_dataset_config, @@ -237,10 +260,11 @@ async def test_square_erasure_request_task( square_test_client, ) -> None: """Full erasure request based on the Square SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) - privacy_request = PrivacyRequest( - id=f"test_square_erasure_request_task_{random.randint(0, 1000)}" - ) identity_kwargs = {"email": square_erasure_identity_email} identity = Identity(**identity_kwargs) privacy_request.cache_identity(identity) @@ -248,9 +272,9 @@ async def test_square_erasure_request_task( dataset_name = square_connection_config.get_saas_config().fides_key merged_graph = square_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [square_connection_config], identity_kwargs, @@ -310,7 +334,7 @@ async def test_square_erasure_request_task( temp_masking = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = True - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_statsig_enterprise_task.py b/tests/ops/integration_tests/saas/test_statsig_enterprise_task.py index efeaf4c996..0b2da4f1f6 100644 --- a/tests/ops/integration_tests/saas/test_statsig_enterprise_task.py +++ b/tests/ops/integration_tests/saas/test_statsig_enterprise_task.py @@ -10,13 +10,21 @@ def test_connection(self, statsig_enterprise_runner: ConnectorRunner): statsig_enterprise_runner.test_connection() @pytest.mark.skip(reason="Enterprise account only") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, + dsr_version, + request, statsig_enterprise_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, statsig_enterprise_erasure_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( _, erasure_results, diff --git a/tests/ops/integration_tests/saas/test_stripe_task.py b/tests/ops/integration_tests/saas/test_stripe_task.py index 80296edd3d..ff0448c3db 100644 --- a/tests/ops/integration_tests/saas/test_stripe_task.py +++ b/tests/ops/integration_tests/saas/test_stripe_task.py @@ -1,18 +1,17 @@ 
-import random from typing import List import pytest import requests from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.filter_results import filter_data_categories from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match +from tests.ops.test_helpers.cache_secrets_helper import clear_cache_identities @pytest.mark.integration_saas @@ -22,18 +21,23 @@ def test_stripe_connection_test(stripe_connection_config) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_stripe_access_request_task_with_email( db, policy, + dsr_version, + request, + privacy_request, stripe_connection_config, stripe_dataset_config, stripe_identity_email, ) -> None: """Full access request based on the Stripe SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_stripe_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": stripe_identity_email}) privacy_request.cache_identity(identity) @@ -41,7 +45,7 @@ async def test_stripe_access_request_task_with_email( merged_graph = stripe_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -592,19 +596,27 @@ async def test_stripe_access_request_task_with_email( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_stripe_access_request_task_with_phone_number( db, policy, + dsr_version, + request, + privacy_request, stripe_connection_config, stripe_dataset_config, stripe_identity_email, stripe_identity_phone_number, ) -> None: """Full access request based on the Stripe SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + # The privacy request fixture we're using already has an email/phone cached, + # so we clear it first + clear_cache_identities(privacy_request.id) - privacy_request = PrivacyRequest( - id=f"test_stripe_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"phone_number": stripe_identity_phone_number}) privacy_request.cache_identity(identity) @@ -612,7 +624,7 @@ async def test_stripe_access_request_task_with_phone_number( merged_graph = stripe_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -657,9 +669,15 @@ async def test_stripe_access_request_task_with_phone_number( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_stripe_erasure_request_task( db, - policy, + privacy_request, + dsr_version, + request, erasure_policy_string_rewrite, stripe_connection_config, stripe_dataset_config, @@ -667,10 +685,11 @@ stripe_erasure_identity_email, stripe_create_erasure_data, ) -> None: """Full erasure request based on the Stripe SaaS config""" +
request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) - privacy_request = PrivacyRequest( - id=f"test_stripe_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": stripe_erasure_identity_email}) privacy_request.cache_identity(identity) @@ -678,9 +697,9 @@ async def test_stripe_erasure_request_task( merged_graph = stripe_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [stripe_connection_config], {"email": stripe_erasure_identity_email}, @@ -1099,7 +1118,7 @@ async def test_stripe_erasure_request_task( masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_surveymonkey_task.py b/tests/ops/integration_tests/saas/test_surveymonkey_task.py index d874c23e7f..a6c6b26a67 100644 --- a/tests/ops/integration_tests/saas/test_surveymonkey_task.py +++ b/tests/ops/integration_tests/saas/test_surveymonkey_task.py @@ -10,12 +10,20 @@ class TestSurveyMonkeyConnector: def test_connection(self, surveymonkey_runner: ConnectorRunner): surveymonkey_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request( self, + dsr_version, + request, surveymonkey_runner: ConnectorRunner, policy, surveymonkey_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + access_results = await surveymonkey_runner.access_request( access_policy=policy, identities={"email": surveymonkey_identity_email} ) @@ -30,8 +38,14 @@ async def test_access_request( == surveymonkey_identity_email ) + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, + dsr_version, + request, surveymonkey_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, @@ -39,6 +53,8 @@ async def test_non_strict_erasure_request( surveymonkey_erasure_data, surveymonkey_client, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( access_results, erasure_results, diff --git a/tests/ops/integration_tests/saas/test_talkable_task.py b/tests/ops/integration_tests/saas/test_talkable_task.py index 7410944077..8aa0da8d71 100644 --- a/tests/ops/integration_tests/saas/test_talkable_task.py +++ b/tests/ops/integration_tests/saas/test_talkable_task.py @@ -9,9 +9,20 @@ class TestTalkableConnector: def test_connection(self, talkable_runner: ConnectorRunner): talkable_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request( - self, talkable_runner: ConnectorRunner, policy, talkable_identity_email: str + self, + talkable_runner: ConnectorRunner, + policy, + request, + dsr_version, + talkable_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + access_results = await talkable_runner.access_request( access_policy=policy, identities={"email": talkable_identity_email} ) @@ -23,14 +34,22 @@ async def test_access_request( ) @pytest.mark.skip(reason="Temporarily disabled test") + @pytest.mark.parametrize( + 
"dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, talkable_runner: ConnectorRunner, policy: Policy, + request, + dsr_version, erasure_policy_string_rewrite: Policy, talkable_erasure_identity_email: str, talkable_erasure_data, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( access_results, erasure_results, diff --git a/tests/ops/integration_tests/saas/test_twilio_conversations_task.py b/tests/ops/integration_tests/saas/test_twilio_conversations_task.py index 48a7fc8b51..6fb3a3983e 100644 --- a/tests/ops/integration_tests/saas/test_twilio_conversations_task.py +++ b/tests/ops/integration_tests/saas/test_twilio_conversations_task.py @@ -1,15 +1,12 @@ -import random - import pytest import requests from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match @@ -22,9 +19,16 @@ def test_twilio_conversations_connection_test( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_twilio_conversations_access_request_task( db, policy, + dsr_version, + request, + privacy_request, twilio_conversations_connection_config, twilio_conversations_dataset_config, twilio_conversations_identity_email, @@ -33,9 +37,8 @@ async def test_twilio_conversations_access_request_task( twilio_postgres_db, ) -> None: """Full access request based on the Twilio Conversations SaaS config""" - privacy_request = PrivacyRequest( - id=f"test_twilio_conversations_access_request_task_{random.randint(0, 1000)}" - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + identity_attribute = "email" identity_value = twilio_conversations_identity_email identity_kwargs = {identity_attribute: identity_value} @@ -46,7 +49,7 @@ async def test_twilio_conversations_access_request_task( merged_graph = twilio_conversations_dataset_config.get_graph() graph = DatasetGraph(*[merged_graph, twilio_postgres_dataset_config.get_graph()]) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -137,9 +140,15 @@ async def test_twilio_conversations_access_request_task( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_twilio_conversations_erasure_request_task( db, - policy, + dsr_version, + request, + privacy_request, twilio_conversations_connection_config, twilio_conversations_dataset_config, connection_config, @@ -151,10 +160,11 @@ async def test_twilio_conversations_erasure_request_task( twilio_conversations_erasure_data, ) -> None: """Full erasure request based on the Twilio SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) - privacy_request = PrivacyRequest( - id=f"test_twilio_conversations_access_request_task_{random.randint(0, 1000)}" - ) identity_attribute = "email" identity_value = twilio_conversations_erasure_identity_email identity_kwargs = 
{identity_attribute: identity_value} @@ -165,9 +175,9 @@ async def test_twilio_conversations_erasure_request_task( merged_graph = twilio_conversations_dataset_config.get_graph() graph = DatasetGraph(*[merged_graph, twilio_postgres_dataset_config.get_graph()]) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [twilio_conversations_connection_config, connection_config], {"email": twilio_conversations_erasure_identity_email}, @@ -258,7 +268,7 @@ async def test_twilio_conversations_erasure_request_task( temp_masking = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = True - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_typeform_task.py b/tests/ops/integration_tests/saas/test_typeform_task.py index a0a1b0498b..4f86a852eb 100644 --- a/tests/ops/integration_tests/saas/test_typeform_task.py +++ b/tests/ops/integration_tests/saas/test_typeform_task.py @@ -9,13 +9,21 @@ class TestTypeformConnector: def test_connection(self, typeform_runner: ConnectorRunner): typeform_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, + dsr_version, + request, typeform_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, typeform_erasure_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( access_results, erasure_results, diff --git a/tests/ops/integration_tests/saas/test_unbounce_task.py b/tests/ops/integration_tests/saas/test_unbounce_task.py index 8758c45d35..4a62857d2a 100644 --- a/tests/ops/integration_tests/saas/test_unbounce_task.py +++ b/tests/ops/integration_tests/saas/test_unbounce_task.py @@ -1,14 +1,11 @@ -import random - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import get_config +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match CONFIG = get_config() @@ -23,18 +20,23 @@ def test_unbounce_connection_test(unbounce_connection_config) -> None: @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_unbounce_access_request_task( db, + dsr_version, + request, policy, + privacy_request, unbounce_connection_config, unbounce_dataset_config, unbounce_identity_email, ) -> None: """Full access request based on the Unbounce SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_unbounce_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": unbounce_identity_email}) privacy_request.cache_identity(identity) @@ -42,7 +44,7 @@ async def test_unbounce_access_request_task( merged_graph = unbounce_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, 
@@ -93,9 +95,15 @@ async def test_unbounce_access_request_task( @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_unbounce_erasure_request_task( db, - policy, + dsr_version, + request, + privacy_request, erasure_policy_string_rewrite, unbounce_connection_config, unbounce_dataset_config, @@ -103,13 +111,14 @@ async def test_unbounce_erasure_request_task( unbounce_create_erasure_data, ) -> None: """Full erasure request based on the Unbounce SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False # Allow Delete - privacy_request = PrivacyRequest( - id=f"test_unbounce_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": unbounce_erasure_identity_email}) privacy_request.cache_identity(identity) @@ -117,9 +126,9 @@ async def test_unbounce_erasure_request_task( merged_graph = unbounce_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [unbounce_connection_config], {"email": unbounce_erasure_identity_email}, @@ -160,7 +169,7 @@ async def test_unbounce_erasure_request_task( ], ) - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_universal_analytics_task.py b/tests/ops/integration_tests/saas/test_universal_analytics_task.py index 9518d49fe3..e6d2b1f9e0 100644 --- a/tests/ops/integration_tests/saas/test_universal_analytics_task.py +++ b/tests/ops/integration_tests/saas/test_universal_analytics_task.py @@ -1,22 +1,16 @@ from unittest import mock -from uuid import uuid4 import pytest from fides.api.models.policy import ActionType -from fides.api.models.privacy_request import ( - ExecutionLog, - ExecutionLogStatus, - PrivacyRequest, - PrivacyRequestStatus, -) +from fides.api.models.privacy_request import ExecutionLog, ExecutionLogStatus from fides.api.schemas.redis_cache import Identity from fides.api.schemas.saas.shared_schemas import SaaSRequestParams from fides.api.service.connectors import get_connector from fides.api.service.privacy_request.request_runner_service import ( build_consent_dataset_graph, ) -from fides.api.task import graph_task +from tests.conftest import consent_runner_tester @pytest.mark.integration_saas @@ -30,26 +24,32 @@ def test_universal_analytics_connection_test( @pytest.mark.integration_saas @pytest.mark.asyncio @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_universal_analytics_consent_request_task_old_workflow( db, + privacy_request, consent_policy, + dsr_version, + request, universal_analytics_connection_config, universal_analytics_dataset_config, universal_analytics_client_id, ) -> None: """Full consent request based on the Google Analytics SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=str(uuid4()), - consent_preferences=[{"data_use": "marketing.advertising", "opt_in": False}], - ) + 
privacy_request.policy_id = consent_policy.id + privacy_request.save(db) identity = Identity(**{"ga_client_id": universal_analytics_client_id}) privacy_request.cache_identity(identity) dataset_name = "universal_analytics_instance" - v = await graph_task.run_consent_request( + v = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([universal_analytics_dataset_config]), @@ -88,25 +88,31 @@ async def test_universal_analytics_consent_request_task_old_workflow( @pytest.mark.asyncio @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_universal_analytics_consent_prepared_requests_old_workflow( mocked_client_send, db, + privacy_request, consent_policy, universal_analytics_connection_config, universal_analytics_dataset_config, universal_analytics_client_id, + dsr_version, + request, ) -> None: """Assert attributes of the PreparedRequest created by the client for running the consent request""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=str(uuid4()), - consent_preferences=[{"data_use": "marketing.advertising", "opt_in": False}], - ) + privacy_request.policy_id = consent_policy.id + privacy_request.save(db) identity = Identity(**{"ga_client_id": universal_analytics_client_id}) privacy_request.cache_identity(identity) - await graph_task.run_consent_request( + consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([universal_analytics_dataset_config]), @@ -131,26 +137,32 @@ async def test_universal_analytics_consent_prepared_requests_old_workflow( @pytest.mark.asyncio @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_universal_analytics_no_ga_client_id_old_workflow( mocked_client_send, db, + privacy_request, consent_policy, universal_analytics_connection_config, universal_analytics_dataset_config, + dsr_version, + request, ) -> None: """Test that the universal analytics connector does not fail if there is no ga_client_id We skip the request because it is marked as skip_missing_param_values=True. We won't always have this piece of identity data. 
""" + privacy_request.policy_id = consent_policy.id + privacy_request.save(db) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=str(uuid4()), - consent_preferences=[{"data_use": "marketing.advertising", "opt_in": False}], - ) dataset_name = "universal_analytics_instance" - v = await graph_task.run_consent_request( + v = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([universal_analytics_dataset_config]), @@ -182,25 +194,32 @@ async def test_universal_analytics_no_ga_client_id_old_workflow( @pytest.mark.integration_saas @pytest.mark.asyncio @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_universal_analytics_no_ga_client_id_new_workflow( mocked_client_send, db, + dsr_version, + request, + privacy_request, consent_policy, universal_analytics_connection_config_without_secrets, universal_analytics_dataset_config_without_secrets, privacy_preference_history, ) -> None: """Test universal analytics connector skips instead of fails if identity missing.""" - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = consent_policy.id privacy_request.save(db) privacy_preference_history.privacy_request_id = privacy_request.id privacy_preference_history.save(db=db) dataset_name = "universal_analytics_instance" - v = await graph_task.run_consent_request( + v = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph( @@ -245,8 +264,15 @@ async def test_universal_analytics_no_ga_client_id_new_workflow( @pytest.mark.integration_saas @pytest.mark.asyncio @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_universal_analytics_consent_request_task_new_workflow( db, + dsr_version, + request, + privacy_request, consent_policy, universal_analytics_connection_config, universal_analytics_dataset_config, @@ -258,12 +284,14 @@ async def test_universal_analytics_consent_request_task_new_workflow( """Full consent request based on the Google Analytics SaaS config for the new workflow where we save consent with respect to privacy preferences """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = consent_policy.id + privacy_request.save(db) + universal_analytics_connection_config.system_id = system.id universal_analytics_connection_config.save(db) - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) privacy_request.save(db) # This preference matches on data use privacy_preference_history.privacy_request_id = privacy_request.id @@ -278,7 +306,7 @@ async def test_universal_analytics_consent_request_task_new_workflow( dataset_name = "universal_analytics_instance" - v = await graph_task.run_consent_request( + v = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([universal_analytics_dataset_config]), @@ -331,9 +359,16 @@ async def test_universal_analytics_consent_request_task_new_workflow( @pytest.mark.asyncio @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") 
+@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_universal_analytics_consent_request_task_new_errored_workflow( mocked_client_send, db, + dsr_version, + request, + privacy_request, consent_policy, universal_analytics_connection_config, universal_analytics_dataset_config, @@ -347,14 +382,16 @@ async def test_universal_analytics_consent_request_task_new_errored_workflow( Assert logging created appropriately """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = consent_policy.id + privacy_request.save(db) + mocked_client_send.side_effect = Exception("KeyError") universal_analytics_connection_config.system_id = system.id universal_analytics_connection_config.save(db) - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) privacy_request.save(db) # This preference matches on data use privacy_preference_history.privacy_request_id = privacy_request.id @@ -370,7 +407,7 @@ async def test_universal_analytics_consent_request_task_new_errored_workflow( dataset_name = "universal_analytics_instance" with pytest.raises(Exception): - await graph_task.run_consent_request( + consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([universal_analytics_dataset_config]), @@ -412,9 +449,16 @@ async def test_universal_analytics_consent_request_task_new_errored_workflow( @pytest.mark.asyncio @pytest.mark.skip(reason="Currently unable to test OAuth2 connectors") @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_universal_analytics_consent_prepared_requests_new_workflow( mocked_client_send, db, + dsr_version, + request, + privacy_request, consent_policy, universal_analytics_connection_config, universal_analytics_dataset_config, @@ -422,10 +466,11 @@ async def test_universal_analytics_consent_prepared_requests_new_workflow( privacy_preference_history, ) -> None: """Assert attributes of the PreparedRequest created by the client for running the consent request""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = consent_policy.id + privacy_request.save(db) - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) privacy_request.save(db) privacy_preference_history.privacy_request_id = privacy_request.id privacy_preference_history.save(db=db) @@ -433,7 +478,7 @@ async def test_universal_analytics_consent_prepared_requests_new_workflow( identity = Identity(**{"ga_client_id": universal_analytics_client_id}) privacy_request.cache_identity(identity) - await graph_task.run_consent_request( + consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([universal_analytics_dataset_config]), diff --git a/tests/ops/integration_tests/saas/test_vend_task.py b/tests/ops/integration_tests/saas/test_vend_task.py index ff6b3d3121..174d1bb969 100644 --- a/tests/ops/integration_tests/saas/test_vend_task.py +++ b/tests/ops/integration_tests/saas/test_vend_task.py @@ -1,15 +1,12 @@ -import random - import pytest import requests from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import 
get_cached_data_for_erasures from fides.config import get_config +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match CONFIG = get_config() @@ -24,18 +21,23 @@ def test_vend_connection_test(vend_connection_config) -> None: @pytest.mark.skip(reason="No active account") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_vend_access_request_task( db, + dsr_version, + request, policy, + privacy_request, vend_connection_config, vend_dataset_config, vend_identity_email, ) -> None: """Full access request based on the Vend SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_vend_access_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": vend_identity_email}) privacy_request.cache_identity(identity) @@ -43,7 +45,7 @@ async def test_vend_access_request_task( merged_graph = vend_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -156,9 +158,15 @@ async def test_vend_access_request_task( @pytest.mark.skip(reason="No active account") @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_vend_erasure_request_task( db, - policy, + dsr_version, + request, + privacy_request, erasure_policy_string_rewrite, vend_connection_config, vend_dataset_config, @@ -166,13 +174,11 @@ async def test_vend_erasure_request_task( vend_create_erasure_data, ) -> None: """Full erasure request based on the Vend SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False # Allow Delete - privacy_request = PrivacyRequest( - id=f"test_vend_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": vend_erasure_identity_email}) privacy_request.cache_identity(identity) @@ -180,9 +186,9 @@ async def test_vend_erasure_request_task( merged_graph = vend_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [vend_connection_config], {"email": vend_erasure_identity_email}, @@ -281,7 +287,7 @@ async def test_vend_erasure_request_task( ], ) - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_wunderkind_task.py b/tests/ops/integration_tests/saas/test_wunderkind_task.py index d198ca36fa..f7304c408e 100644 --- a/tests/ops/integration_tests/saas/test_wunderkind_task.py +++ b/tests/ops/integration_tests/saas/test_wunderkind_task.py @@ -1,22 +1,16 @@ from unittest import mock -from uuid import uuid4 import pytest from fides.api.models.policy import ActionType -from fides.api.models.privacy_request import ( - ExecutionLog, - ExecutionLogStatus, - PrivacyRequest, - PrivacyRequestStatus, -) +from fides.api.models.privacy_request import ExecutionLog, ExecutionLogStatus from fides.api.schemas.redis_cache import Identity from fides.api.schemas.saas.shared_schemas import SaaSRequestParams from fides.api.service.connectors import get_connector from 
fides.api.service.privacy_request.request_runner_service import ( build_consent_dataset_graph, ) -from fides.api.task import graph_task +from tests.conftest import consent_runner_tester @pytest.mark.integration_saas @@ -28,26 +22,34 @@ def test_wunderkind_connection_test( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_wunderkind_consent_request_task_old_workflow( db, consent_policy, wunderkind_connection_config, wunderkind_dataset_config, wunderkind_identity_email, + privacy_request, + dsr_version, + request, ) -> None: """Full consent request based on the Wunderkind SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=str(uuid4()), - consent_preferences=[{"data_use": "marketing.advertising", "opt_in": False}], - ) + privacy_request.consent_preferences = [ + {"data_use": "marketing.advertising", "opt_in": False} + ] + privacy_request.save(db) identity = Identity(**{"email": wunderkind_identity_email}) privacy_request.cache_identity(identity) dataset_name = "wunderkind_instance" - v = await graph_task.run_consent_request( + v = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([wunderkind_dataset_config]), @@ -85,6 +87,10 @@ async def test_wunderkind_consent_request_task_old_workflow( @pytest.mark.integration_saas @pytest.mark.asyncio @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_wunderkind_consent_prepared_requests_old_workflow( mocked_client_send, db, @@ -92,18 +98,22 @@ async def test_wunderkind_consent_prepared_requests_old_workflow( wunderkind_connection_config, wunderkind_dataset_config, wunderkind_identity_email, + dsr_version, + request, + privacy_request, ) -> None: """Assert attributes of the PreparedRequest created by the client for running the consent request""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=str(uuid4()), - consent_preferences=[{"data_use": "marketing.advertising", "opt_in": False}], - ) + privacy_request.consent_preferences = [ + {"data_use": "marketing.advertising", "opt_in": False} + ] + privacy_request.save(db) identity = Identity(**{"email": wunderkind_identity_email}) privacy_request.cache_identity(identity) - await graph_task.run_consent_request( + consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([wunderkind_dataset_config]), @@ -126,6 +136,10 @@ async def test_wunderkind_consent_prepared_requests_old_workflow( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_wunderkind_consent_request_task_new_workflow( db, consent_policy, @@ -135,16 +149,16 @@ async def test_wunderkind_consent_request_task_new_workflow( privacy_preference_history, privacy_preference_history_us_ca_provide, system, + dsr_version, + request, + privacy_request, ) -> None: """Full consent request based on the Wunderkind SaaS config and new workflow""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 wunderkind_connection_config.system_id = system.id wunderkind_connection_config.save(db) - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) - privacy_request.save(db) # This 
preference matches on data use privacy_preference_history.privacy_request_id = privacy_request.id privacy_preference_history.save(db=db) @@ -158,7 +172,7 @@ async def test_wunderkind_consent_request_task_new_workflow( dataset_name = "wunderkind_instance" - v = await graph_task.run_consent_request( + v = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([wunderkind_dataset_config]), @@ -210,6 +224,10 @@ async def test_wunderkind_consent_request_task_new_workflow( @pytest.mark.integration_saas @pytest.mark.asyncio @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_wunderkind_errored_logging_new_workflow( mocked_client_send, db, @@ -220,16 +238,17 @@ async def test_wunderkind_errored_logging_new_workflow( privacy_preference_history, privacy_preference_history_us_ca_provide, system, + dsr_version, + request, + privacy_request, ) -> None: """Test wunderkind errors have proper logs created""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + mocked_client_send.side_effect = Exception("KeyError") wunderkind_connection_config.system_id = system.id wunderkind_connection_config.save(db) - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) - privacy_request.save(db) # This preference matches on data use privacy_preference_history.privacy_request_id = privacy_request.id privacy_preference_history.save(db=db) @@ -243,8 +262,18 @@ async def test_wunderkind_errored_logging_new_workflow( dataset_name = "wunderkind_instance" - with pytest.raises(Exception): - await graph_task.run_consent_request( + if dsr_version == "use_dsr_2_0": + with pytest.raises(Exception): + consent_runner_tester( + privacy_request, + consent_policy, + build_consent_dataset_graph([wunderkind_dataset_config]), + [wunderkind_connection_config], + {"email": wunderkind_identity_email}, + db, + ) + else: + consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([wunderkind_dataset_config]), @@ -252,6 +281,12 @@ async def test_wunderkind_errored_logging_new_workflow( {"email": wunderkind_identity_email}, db, ) + # Current task and terminator task were marked as error + assert [rt.status.value for rt in privacy_request.consent_tasks] == [ + "complete", + "error", + "error", + ] execution_logs = db.query(ExecutionLog).filter_by( privacy_request_id=privacy_request.id @@ -285,6 +320,10 @@ async def test_wunderkind_errored_logging_new_workflow( @pytest.mark.integration_saas @pytest.mark.asyncio @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_wunderkind_consent_prepared_requests_new_workflow( mocked_client_send, db, @@ -293,20 +332,20 @@ async def test_wunderkind_consent_prepared_requests_new_workflow( wunderkind_dataset_config, wunderkind_identity_email, privacy_preference_history, + dsr_version, + request, + privacy_request, ) -> None: """Assert attributes of the PreparedRequest created by the client for running the consent request""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) - privacy_request.save(db) privacy_preference_history.privacy_request_id = privacy_request.id privacy_preference_history.save(db=db) identity = 
Identity(**{"email": wunderkind_identity_email}) privacy_request.cache_identity(identity) - await graph_task.run_consent_request( + consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([wunderkind_dataset_config]), @@ -329,6 +368,10 @@ async def test_wunderkind_consent_prepared_requests_new_workflow( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_wunderkind_skipped_new_workflow( db, consent_policy, @@ -337,15 +380,16 @@ async def test_wunderkind_skipped_new_workflow( wunderkind_identity_email, system, privacy_preference_history_us_ca_provide, + dsr_version, + request, + privacy_request, ) -> None: """Data use mismatch between notice and system should cause request to not fire""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + wunderkind_connection_config.system_id = system.id wunderkind_connection_config.save(db) - privacy_request = PrivacyRequest( - id=str(uuid4()), status=PrivacyRequestStatus.pending - ) - privacy_request.save(db) privacy_preference_history_us_ca_provide.privacy_request_id = privacy_request.id privacy_preference_history_us_ca_provide.save(db=db) @@ -354,7 +398,7 @@ async def test_wunderkind_skipped_new_workflow( dataset_name = "wunderkind_instance" - v = await graph_task.run_consent_request( + v = consent_runner_tester( privacy_request, consent_policy, build_consent_dataset_graph([wunderkind_dataset_config]), diff --git a/tests/ops/integration_tests/saas/test_yotpo_loyalty_task.py b/tests/ops/integration_tests/saas/test_yotpo_loyalty_task.py index 7a95a10f54..f185cc9096 100644 --- a/tests/ops/integration_tests/saas/test_yotpo_loyalty_task.py +++ b/tests/ops/integration_tests/saas/test_yotpo_loyalty_task.py @@ -1,16 +1,13 @@ -import random -from time import sleep - import pytest from fides.api.graph.graph import DatasetGraph -from fides.api.models.privacy_request import PrivacyRequest from fides.api.schemas.redis_cache import Identity from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import get_config +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match +from tests.ops.test_helpers.cache_secrets_helper import clear_cache_identities from tests.ops.test_helpers.saas_test_utils import poll_for_existence CONFIG = get_config() @@ -23,18 +20,23 @@ def test_yotpo_loyalty_connection_test(yotpo_loyalty_connection_config) -> None: @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_yotpo_loyalty_access_request_task_with_email( db, + privacy_request, policy, + dsr_version, + request, yotpo_loyalty_connection_config, yotpo_loyalty_dataset_config, yotpo_loyalty_identity_email, ) -> None: """Full access request based on the Yotpo Loyalty & Referrals SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest( - id=f"test_yotpo_loyalty_access_request_task_with_email_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": yotpo_loyalty_identity_email}) privacy_request.cache_identity(identity) @@ -42,7 +44,7 @@ async def test_yotpo_loyalty_access_request_task_with_email( merged_graph = yotpo_loyalty_dataset_config.get_graph() graph = 
DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -86,18 +88,25 @@ async def test_yotpo_loyalty_access_request_task_with_email( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_yotpo_loyalty_access_request_task_with_phone_number( db, policy, + dsr_version, + request, + privacy_request, yotpo_loyalty_connection_config, yotpo_loyalty_dataset_config, yotpo_loyalty_identity_phone_number, ) -> None: """Full access request based on the Yotpo Loyalty & Referrals SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + # Privacy request fixture caches email and phone identities - clearing this first - + clear_cache_identities(privacy_request.id) - privacy_request = PrivacyRequest( - id=f"test_yotpo_loyalty_access_request_task_with_phone_number_{random.randint(0, 1000)}" - ) identity = Identity(**{"phone_number": yotpo_loyalty_identity_phone_number}) privacy_request.cache_identity(identity) @@ -105,7 +114,7 @@ async def test_yotpo_loyalty_access_request_task_with_phone_number( merged_graph = yotpo_loyalty_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, graph, @@ -152,9 +161,15 @@ async def test_yotpo_loyalty_access_request_task_with_phone_number( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_yotpo_loyalty_erasure_request_task( db, - policy, + dsr_version, + request, + privacy_request, erasure_policy_string_rewrite, yotpo_loyalty_connection_config, yotpo_loyalty_dataset_config, @@ -163,13 +178,14 @@ async def test_yotpo_loyalty_erasure_request_task( yotpo_loyalty_test_client, ) -> None: """Full erasure request based on the Yotpo Loyalty & Referrals SaaS config""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + privacy_request.policy_id = erasure_policy_string_rewrite.id + privacy_request.save(db) masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False - privacy_request = PrivacyRequest( - id=f"test_yotpo_loyalty_erasure_request_task_{random.randint(0, 1000)}" - ) identity = Identity(**{"email": yotpo_loyalty_erasure_identity_email}) privacy_request.cache_identity(identity) @@ -177,9 +193,9 @@ async def test_yotpo_loyalty_erasure_request_task( merged_graph = yotpo_loyalty_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_string_rewrite, graph, [yotpo_loyalty_connection_config], {"email": yotpo_loyalty_erasure_identity_email}, @@ -215,7 +231,7 @@ async def test_yotpo_loyalty_erasure_request_task( ], ) - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_string_rewrite, graph, diff --git a/tests/ops/integration_tests/saas/test_yotpo_reviews_task.py b/tests/ops/integration_tests/saas/test_yotpo_reviews_task.py index 229dc18c12..d361a94b88 100644 --- a/tests/ops/integration_tests/saas/test_yotpo_reviews_task.py +++ b/tests/ops/integration_tests/saas/test_yotpo_reviews_task.py @@ -11,25 +11,41 @@ class TestYotpoReviewsConnector: def test_connection(self, yotpo_reviews_runner: ConnectorRunner): yotpo_reviews_runner.test_connection() + 
@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request( self, + dsr_version, + request, yotpo_reviews_runner: ConnectorRunner, policy, yotpo_reviews_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + await yotpo_reviews_runner.access_request( access_policy=policy, identities={"email": yotpo_reviews_identity_email} ) @pytest.mark.skip(reason="Temporarily disabled test") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_strict_erasure_request( self, + dsr_version, + request, yotpo_reviews_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, yotpo_reviews_erasure_data, yotpo_reviews_test_client: YotpoReviewsTestClient, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + email, external_id = yotpo_reviews_erasure_data (_, erasure_results) = await yotpo_reviews_runner.strict_erasure_request( access_policy=policy, diff --git a/tests/ops/integration_tests/saas/test_zendesk_task.py b/tests/ops/integration_tests/saas/test_zendesk_task.py index 4b608fdef6..89ff64bffd 100644 --- a/tests/ops/integration_tests/saas/test_zendesk_task.py +++ b/tests/ops/integration_tests/saas/test_zendesk_task.py @@ -10,12 +10,20 @@ class TestZendeskConnector: def test_connection(self, zendesk_runner: ConnectorRunner): zendesk_runner.test_connection() + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_request( self, + dsr_version, + request, zendesk_runner: ConnectorRunner, policy: Policy, zendesk_identity_email: str, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + access_results = await zendesk_runner.access_request( access_policy=policy, identities={"email": zendesk_identity_email} ) @@ -38,8 +46,14 @@ async def test_access_request( for ticket_comment in access_results["zendesk_instance:ticket_comments"]: assert ticket_comment["author_id"] == user_id + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_non_strict_erasure_request( self, + dsr_version, + request, zendesk_runner: ConnectorRunner, policy: Policy, erasure_policy_string_rewrite: Policy, @@ -47,6 +61,8 @@ async def test_non_strict_erasure_request( zendesk_erasure_data, zendesk_client, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + ( access_results, erasure_results, diff --git a/tests/ops/integration_tests/test_email_task.py b/tests/ops/integration_tests/test_email_task.py deleted file mode 100644 index 8b13789179..0000000000 --- a/tests/ops/integration_tests/test_email_task.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tests/ops/integration_tests/test_enabled_actions.py b/tests/ops/integration_tests/test_enabled_actions.py index c249ba012c..1276ed7556 100644 --- a/tests/ops/integration_tests/test_enabled_actions.py +++ b/tests/ops/integration_tests/test_enabled_actions.py @@ -4,9 +4,12 @@ from fides.api.graph.graph import DatasetGraph from fides.api.models.connectionconfig import ActionType from fides.api.models.datasetconfig import convert_dataset_to_graph -from fides.api.models.privacy_request import PrivacyRequest, PrivacyRequestStatus -from fides.api.task import graph_task -from fides.api.task.graph_task import get_cached_data_for_erasures +from fides.api.models.privacy_request import PrivacyRequestStatus +from fides.api.task.graph_runners import access_runner, 
erasure_runner +from fides.api.task.graph_task import ( + filter_by_enabled_actions, + get_cached_data_for_erasures, +) from tests.ops.integration_tests.saas.connector_runner import dataset_config from tests.ops.service.privacy_request.test_request_runner_service import ( get_privacy_request_results, ) @@ -31,20 +34,16 @@ def dataset_graph( @pytest.mark.asyncio async def test_access_disabled( - self, - db, - policy, - integration_postgres_config, - dataset_graph, + self, db, policy, integration_postgres_config, dataset_graph, privacy_request ) -> None: """Disable the access request for one connection config and verify the access results""" - + # Not testing this with both the DSR 2.0 and DSR 3.0 schedulers because filter_by_enabled_actions + # now runs after the access step # disable the access action type for Postgres integration_postgres_config.enabled_actions = [ActionType.erasure] integration_postgres_config.save(db) - privacy_request = PrivacyRequest(id="test_disable_postgres_access") - access_results = await graph_task.run_access_request( + access_runner( privacy_request, policy, dataset_graph, @@ -52,8 +51,12 @@ async def test_access_disabled( {"email": "customer-1@example.com"}, db, ) + raw_access_results = privacy_request.get_raw_access_results() + filtered_access_results = filter_by_enabled_actions( + raw_access_results, [integration_postgres_config] + ) - assert access_results == {} + assert filtered_access_results == {} @pytest.mark.asyncio async def test_erasure_disabled( @@ -63,16 +66,16 @@ async def test_erasure_disabled( self, db, policy, erasure_policy, integration_postgres_config, dataset_graph, + privacy_request_with_erasure_policy, ) -> None: """Disable the erasure request for one connection config and verify the erasure results""" # disable the erasure action type for Postgres integration_postgres_config.enabled_actions = [ActionType.access] integration_postgres_config.save(db) - privacy_request = PrivacyRequest(id="test_disable_postgres_erasure") - access_results = await graph_task.run_access_request( - privacy_request, + access_results = access_runner( + privacy_request_with_erasure_policy, policy, dataset_graph, [integration_postgres_config], @@ -86,13 +89,13 @@ async def test_erasure_disabled( postgres_dataset, } - erasure_results = await graph_task.run_erasure( - privacy_request, + erasure_results = erasure_runner( + privacy_request_with_erasure_policy, erasure_policy, dataset_graph, [integration_postgres_config], {"email": "customer-1@example.com"}, - get_cached_data_for_erasures(privacy_request.id), + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) @@ -100,15 +103,23 @@ assert erasure_results == {} @pytest.mark.asyncio + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_access_disabled_for_manual_webhook_integrations( self, db, + dsr_version, + request, policy, integration_postgres_config, integration_manual_webhook_config, access_manual_webhook, run_privacy_request_task, ) -> None: + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + pr = get_privacy_request_results( db, policy, @@ -143,9 +154,15 @@ assert pr.status == PrivacyRequestStatus.complete @pytest.mark.asyncio + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_erasure_disabled_for_manual_webhook_integrations( self, db, + dsr_version, + request, policy, erasure_policy, 
integration_postgres_config, @@ -153,6 +170,8 @@ async def test_erasure_disabled_for_manual_webhook_integrations( access_manual_webhook, run_privacy_request_task, ) -> None: + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + pr = get_privacy_request_results( db, erasure_policy, diff --git a/tests/ops/integration_tests/test_execution.py b/tests/ops/integration_tests/test_execution.py index 740b4efd4e..b6acb88521 100644 --- a/tests/ops/integration_tests/test_execution.py +++ b/tests/ops/integration_tests/test_execution.py @@ -1,4 +1,4 @@ -import uuid +from datetime import datetime from typing import Optional from unittest import mock @@ -23,11 +23,11 @@ ExecutionLog, PrivacyRequest, ) -from fides.api.task import graph_task from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG from tests.fixtures.application_fixtures import integration_secrets +from ...conftest import access_runner_tester, erasure_runner_tester from ..service.privacy_request.test_request_runner_service import ( get_privacy_request_results, ) @@ -65,9 +65,15 @@ class TestDeleteCollection: @pytest.mark.usefixtures( "postgres_integration_db", "postgres_example_test_dataset_config_read_access" ) + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_delete_collection_before_new_request( self, db, + dsr_version, + request, policy, read_connection_config, run_privacy_request_task, @@ -75,7 +81,9 @@ def test_delete_collection_before_new_request( """Delete the connection config before execution starts which also deletes its dataset config. The graph is built with nothing in it, and no results are returned. """ - customer_email = "customer-1@example.com" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-4@example.com" data = { "requested_at": "2021-08-30T16:09:37.359Z", "policy_key": policy.key, @@ -100,6 +108,10 @@ def test_delete_collection_before_new_request( assert pr.get_results() == {} @mock.patch("fides.api.task.graph_task.GraphTask.log_start") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) @pytest.mark.asyncio async def test_delete_collection_while_in_progress( self, @@ -108,10 +120,15 @@ async def test_delete_collection_while_in_progress( policy, integration_postgres_config, example_datasets, + privacy_request, + dsr_version, + request, ) -> None: """Assert that deleting a collection while the privacy request is in progress doesn't affect the current execution plan. We still proceed to visit the deleted collections, because we rely on the ConnectionConfigs already in memory. """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + # Create a new ConnectionConfig instead of using the fixture because I need to be able to access this # outside of the current session. 
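`get_sorted_execution_logs`, asserted against throughout the assertions below, is a suite helper defined outside this diff. Judging from the `("dataset:collection", "status")` tuples it is compared to, it presumably reduces to something like this sketch (the import path for `CollectionAddress` is an assumption):

```python
from fides.api.graph.config import CollectionAddress  # assumed import path
from fides.api.models.privacy_request import ExecutionLog


def get_sorted_execution_logs(db, privacy_request):
    """Return ("dataset:collection", "status") tuples ordered by log creation time."""
    return [
        (
            CollectionAddress(log.dataset_name, log.collection_name).value,
            log.status.value,
        )
        for log in db.query(ExecutionLog)
        .filter_by(privacy_request_id=privacy_request.id)
        .order_by(ExecutionLog.created_at)
    ]
```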
mongo_connection_config = ConnectionConfig( @@ -150,16 +167,13 @@ def delete_connection_config(_): new_session.close() mocked_log_start.side_effect = delete_connection_config - privacy_request = PrivacyRequest( - id=f"test_postgres_access_request_task_{uuid.uuid4()}" - ) - results = await graph_task.run_access_request( + results = access_runner_tester( privacy_request, policy, dataset_graph, [integration_postgres_config, mongo_connection_config], - {"email": "customer-1@example.com"}, + {"email": "customer-4@example.com"}, db, ) assert any( @@ -188,6 +202,10 @@ def delete_connection_config(_): db.delete(mongo_connection_config) @pytest.mark.asyncio + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_collection_omitted_on_restart_from_failure( self, db, @@ -196,45 +214,73 @@ async def test_collection_omitted_on_restart_from_failure( integration_mongodb_config, mongo_postgres_dataset_graph, example_datasets, - run_privacy_request_task, + privacy_request, + dsr_version, + request, ) -> None: """Remove secrets to make privacy request fail, then delete the connection config. Build a graph that does not contain the deleted dataset config and re-run.""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 integration_mongodb_config.secrets = {} integration_mongodb_config.save(db) - privacy_request = PrivacyRequest( - id=f"test_postgres_access_request_task_{uuid.uuid4()}" - ) + if "use_dsr_2_0" == dsr_version: + with pytest.raises(ValidationError): + access_runner_tester( + privacy_request, + policy, + mongo_postgres_dataset_graph, + [integration_postgres_config, integration_mongodb_config], + {"email": "customer-4@example.com"}, + db, + ) - with pytest.raises(ValidationError): - await graph_task.run_access_request( + execution_logs = get_sorted_execution_logs(db, privacy_request) + assert execution_logs == [ + ("postgres_example_test_dataset:customer", "in_processing"), + ("postgres_example_test_dataset:customer", "complete"), + ("postgres_example_test_dataset:payment_card", "in_processing"), + ("postgres_example_test_dataset:payment_card", "complete"), + ("postgres_example_test_dataset:orders", "in_processing"), + ("postgres_example_test_dataset:orders", "complete"), + ("postgres_example_test_dataset:order_item", "in_processing"), + ("postgres_example_test_dataset:order_item", "complete"), + ("postgres_example_test_dataset:product", "in_processing"), + ("postgres_example_test_dataset:product", "complete"), + ("mongo_test:customer_details", "in_processing"), + ("mongo_test:customer_details", "error"), + ], "Execution failed at first mongo collection" + + integration_mongodb_config.delete(db) + + else: + access_runner_tester( privacy_request, policy, mongo_postgres_dataset_graph, [integration_postgres_config, integration_mongodb_config], - {"email": "customer-1@example.com"}, + {"email": "customer-4@example.com"}, db, ) - - execution_logs = get_sorted_execution_logs(db, privacy_request) - assert execution_logs == [ - ("postgres_example_test_dataset:customer", "in_processing"), - ("postgres_example_test_dataset:customer", "complete"), - ("postgres_example_test_dataset:payment_card", "in_processing"), - ("postgres_example_test_dataset:payment_card", "complete"), - ("postgres_example_test_dataset:orders", "in_processing"), - ("postgres_example_test_dataset:orders", "complete"), - ("postgres_example_test_dataset:order_item", "in_processing"), - ("postgres_example_test_dataset:order_item", "complete"), - 
("postgres_example_test_dataset:product", "in_processing"), - ("postgres_example_test_dataset:product", "complete"), - ("mongo_test:customer_details", "in_processing"), - ("mongo_test:customer_details", "error"), - ], "Execution failed at first mongo collection" - - integration_mongodb_config.delete(db) + customer_detail_logs = db.query(ExecutionLog).filter_by( + privacy_request_id=privacy_request.id, + dataset_name="mongo_test", + collection_name="customer_details", + ) + assert ["in_processing", "error"] == [ + log.status.value + for log in customer_detail_logs.order_by(ExecutionLog.created_at) + ] + customer_feedback_logs = db.query(ExecutionLog).filter_by( + privacy_request_id=privacy_request.id, + dataset_name="mongo_test", + collection_name="customer_feedback", + ) + assert ["in_processing", "error"] == [ + log.status.value + for log in customer_feedback_logs.order_by(ExecutionLog.created_at) + ] # Just rebuilding a graph without the deleted config. dataset_postgres = Dataset(**example_datasets[0]) @@ -243,59 +289,90 @@ async def test_collection_omitted_on_restart_from_failure( ) postgres_only_dataset_graph = DatasetGraph(*[graph]) - results = await graph_task.run_access_request( + results = access_runner_tester( privacy_request, policy, postgres_only_dataset_graph, [integration_postgres_config], - {"email": "customer-1@example.com"}, + {"email": "customer-4@example.com"}, db, ) - execution_logs = get_sorted_execution_logs(db, privacy_request) - assert execution_logs == [ - ("postgres_example_test_dataset:customer", "in_processing"), - ("postgres_example_test_dataset:customer", "complete"), - ("postgres_example_test_dataset:payment_card", "in_processing"), - ("postgres_example_test_dataset:payment_card", "complete"), - ("postgres_example_test_dataset:orders", "in_processing"), - ("postgres_example_test_dataset:orders", "complete"), - ("postgres_example_test_dataset:order_item", "in_processing"), - ("postgres_example_test_dataset:order_item", "complete"), - ("postgres_example_test_dataset:product", "in_processing"), - ("postgres_example_test_dataset:product", "complete"), - ("mongo_test:customer_details", "in_processing"), - ("mongo_test:customer_details", "error"), - ("postgres_example_test_dataset:employee", "in_processing"), - ("postgres_example_test_dataset:employee", "complete"), - ("postgres_example_test_dataset:service_request", "in_processing"), - ("postgres_example_test_dataset:service_request", "complete"), - ("postgres_example_test_dataset:report", "in_processing"), - ("postgres_example_test_dataset:report", "complete"), - ("postgres_example_test_dataset:visit", "in_processing"), - ("postgres_example_test_dataset:visit", "complete"), - ("postgres_example_test_dataset:address", "in_processing"), - ("postgres_example_test_dataset:address", "complete"), - ("postgres_example_test_dataset:login", "in_processing"), - ("postgres_example_test_dataset:login", "complete"), - ], "No mongo collections run" - - assert all( - [dataset.startswith("postgres_example") for dataset in results] - ), "No mongo results" + if "use_dsr_2_0" == dsr_version: + execution_logs = get_sorted_execution_logs(db, privacy_request) + assert execution_logs == [ + ("postgres_example_test_dataset:customer", "in_processing"), + ("postgres_example_test_dataset:customer", "complete"), + ("postgres_example_test_dataset:payment_card", "in_processing"), + ("postgres_example_test_dataset:payment_card", "complete"), + ("postgres_example_test_dataset:orders", "in_processing"), + ("postgres_example_test_dataset:orders", 
"complete"), + ("postgres_example_test_dataset:order_item", "in_processing"), + ("postgres_example_test_dataset:order_item", "complete"), + ("postgres_example_test_dataset:product", "in_processing"), + ("postgres_example_test_dataset:product", "complete"), + ("mongo_test:customer_details", "in_processing"), + ("mongo_test:customer_details", "error"), + ("postgres_example_test_dataset:employee", "in_processing"), + ("postgres_example_test_dataset:employee", "complete"), + ("postgres_example_test_dataset:service_request", "in_processing"), + ("postgres_example_test_dataset:service_request", "complete"), + ("postgres_example_test_dataset:report", "in_processing"), + ("postgres_example_test_dataset:report", "complete"), + ("postgres_example_test_dataset:visit", "in_processing"), + ("postgres_example_test_dataset:visit", "complete"), + ("postgres_example_test_dataset:address", "in_processing"), + ("postgres_example_test_dataset:address", "complete"), + ("postgres_example_test_dataset:login", "in_processing"), + ("postgres_example_test_dataset:login", "complete"), + ], "No mongo collections run" + + assert all( + [dataset.startswith("postgres_example") for dataset in results] + ), "No mongo results" + + else: + # For DSR 3.0 we don't rebuild the graph - we try to run the original graph that we saved + # to the database initially. These nodes try to run again and error. + customer_detail_logs = db.query(ExecutionLog).filter_by( + privacy_request_id=privacy_request.id, + dataset_name="mongo_test", + collection_name="customer_details", + ) + assert ["in_processing", "error", "in_processing", "error"] == [ + log.status.value + for log in customer_detail_logs.order_by(ExecutionLog.created_at) + ] + customer_feedback_logs = db.query(ExecutionLog).filter_by( + privacy_request_id=privacy_request.id, + dataset_name="mongo_test", + collection_name="customer_feedback", + ) + assert ["in_processing", "error", "in_processing", "error"] == [ + log.status.value + for log in customer_feedback_logs.order_by(ExecutionLog.created_at) + ] @pytest.mark.usefixtures( "postgres_integration_db", "postgres_example_test_dataset_config_read_access" ) + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_delete_connection_config_on_completed_request( self, db, + dsr_version, + request, policy, read_connection_config, run_privacy_request_task, ) -> None: """Delete the connection config on a completed request leaves execution logs untouched""" - customer_email = "customer-1@example.com" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-4@example.com" data = { "requested_at": "2021-08-30T16:09:37.359Z", "policy_key": policy.key, @@ -320,6 +397,10 @@ def test_delete_connection_config_on_completed_request( @pytest.mark.integration class TestSkipCollectionDueToDisabledConnectionConfig: @pytest.mark.asyncio + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_skip_collection_new_request( self, db, @@ -327,23 +408,23 @@ async def test_skip_collection_new_request( integration_postgres_config, integration_mongodb_config, mongo_postgres_dataset_graph, + dsr_version, + request, + privacy_request, ) -> None: """Mark Mongo ConnectionConfig as disabled, run access request, and then assert that all mongo collections are skipped""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 integration_mongodb_config.disabled = True integration_mongodb_config.save(db) - 
privacy_request = PrivacyRequest( - id=f"test_postgres_access_request_task_{uuid.uuid4()}" - ) - - results = await graph_task.run_access_request( + results = access_runner_tester( privacy_request, policy, mongo_postgres_dataset_graph, [integration_postgres_config, integration_mongodb_config], - {"email": "customer-1@example.com"}, + {"email": "customer-4@example.com"}, db, ) assert all( @@ -366,17 +447,26 @@ async def test_skip_collection_new_request( @mock.patch("fides.api.task.graph_task.GraphTask.log_start") @pytest.mark.asyncio + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_run_disabled_collections_in_progress( self, mocked_log_start, db, policy, + privacy_request, integration_postgres_config, example_datasets, + dsr_version, + request, ) -> None: """Assert that disabling a collection while the privacy request is in progress can affect the current execution plan. ConnectionConfigs that are disabled while a request is in progress will be skipped after the current session is committed. """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + # Create a new ConnectionConfig instead of using the fixture because I need to be able to access this # outside of the current session. mongo_connection_config = ConnectionConfig( @@ -419,16 +509,12 @@ def disable_connection_config(_): mocked_log_start.side_effect = disable_connection_config - privacy_request = PrivacyRequest( - id=f"test_postgres_access_request_task_{uuid.uuid4()}" - ) - - results = await graph_task.run_access_request( + results = access_runner_tester( privacy_request, policy, dataset_graph, [integration_postgres_config, mongo_connection_config], - {"email": "customer-1@example.com"}, + {"email": "customer-4@example.com"}, db, ) assert not any( @@ -457,6 +543,10 @@ def disable_connection_config(_): db.delete(mongo_connection_config) @pytest.mark.asyncio + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_skip_collection_on_restart( self, db, @@ -464,108 +554,170 @@ async def test_skip_collection_on_restart( integration_postgres_config, integration_mongodb_config, mongo_postgres_dataset_graph, + privacy_request, + dsr_version, + request, ) -> None: """Remove secrets to make privacy request fail, then disable connection config and confirm that datastores are skipped on re-run""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 integration_mongodb_config.secrets = {} integration_mongodb_config.save(db) - privacy_request = PrivacyRequest( - id=f"test_postgres_access_request_task_{uuid.uuid4()}" - ) + if dsr_version == "use_dsr_2_0": + with pytest.raises(ValidationError): + access_runner_tester( + privacy_request, + policy, + mongo_postgres_dataset_graph, + [integration_postgres_config, integration_mongodb_config], + {"email": "customer-4@example.com"}, + db, + ) - with pytest.raises(ValidationError): - await graph_task.run_access_request( + execution_logs = get_sorted_execution_logs(db, privacy_request) + assert execution_logs == [ + ("postgres_example_test_dataset:customer", "in_processing"), + ("postgres_example_test_dataset:customer", "complete"), + ("postgres_example_test_dataset:payment_card", "in_processing"), + ("postgres_example_test_dataset:payment_card", "complete"), + ("postgres_example_test_dataset:orders", "in_processing"), + ("postgres_example_test_dataset:orders", "complete"), + ("postgres_example_test_dataset:order_item", "in_processing"), + 
("postgres_example_test_dataset:order_item", "complete"), + ("postgres_example_test_dataset:product", "in_processing"), + ("postgres_example_test_dataset:product", "complete"), + ("mongo_test:customer_details", "in_processing"), + ("mongo_test:customer_details", "error"), + ], "Execution failed at first mongo collection" + + else: + access_runner_tester( privacy_request, policy, mongo_postgres_dataset_graph, [integration_postgres_config, integration_mongodb_config], - {"email": "customer-1@example.com"}, + {"email": "customer-4@example.com"}, db, ) - execution_logs = get_sorted_execution_logs(db, privacy_request) - assert execution_logs == [ - ("postgres_example_test_dataset:customer", "in_processing"), - ("postgres_example_test_dataset:customer", "complete"), - ("postgres_example_test_dataset:payment_card", "in_processing"), - ("postgres_example_test_dataset:payment_card", "complete"), - ("postgres_example_test_dataset:orders", "in_processing"), - ("postgres_example_test_dataset:orders", "complete"), - ("postgres_example_test_dataset:order_item", "in_processing"), - ("postgres_example_test_dataset:order_item", "complete"), - ("postgres_example_test_dataset:product", "in_processing"), - ("postgres_example_test_dataset:product", "complete"), - ("mongo_test:customer_details", "in_processing"), - ("mongo_test:customer_details", "error"), - ], "Execution failed at first mongo collection" + # DSR 3.0 can run multiple nodes in parallel - two mongo nodes were able to attempt to run + # before failing and blocking downstream nodes + customer_detail_logs = db.query(ExecutionLog).filter_by( + privacy_request_id=privacy_request.id, + dataset_name="mongo_test", + collection_name="customer_details", + ) + assert ["in_processing", "error"] == [ + log.status.value + for log in customer_detail_logs.order_by(ExecutionLog.created_at) + ] + customer_feedback_logs = db.query(ExecutionLog).filter_by( + privacy_request_id=privacy_request.id, + dataset_name="mongo_test", + collection_name="customer_feedback", + ) + assert ["in_processing", "error"] == [ + log.status.value + for log in customer_feedback_logs.order_by(ExecutionLog.created_at) + ] integration_mongodb_config.disabled = True integration_mongodb_config.save(db) - results = await graph_task.run_access_request( + results = access_runner_tester( privacy_request, policy, mongo_postgres_dataset_graph, [integration_postgres_config, integration_mongodb_config], - {"email": "customer-1@example.com"}, + {"email": "customer-4@example.com"}, db, ) - execution_logs = get_sorted_execution_logs(db, privacy_request) - assert execution_logs == [ - ("postgres_example_test_dataset:customer", "in_processing"), - ("postgres_example_test_dataset:customer", "complete"), - ("postgres_example_test_dataset:payment_card", "in_processing"), - ("postgres_example_test_dataset:payment_card", "complete"), - ("postgres_example_test_dataset:orders", "in_processing"), - ("postgres_example_test_dataset:orders", "complete"), - ("postgres_example_test_dataset:order_item", "in_processing"), - ("postgres_example_test_dataset:order_item", "complete"), - ("postgres_example_test_dataset:product", "in_processing"), - ("postgres_example_test_dataset:product", "complete"), - ("mongo_test:customer_details", "in_processing"), - ("mongo_test:customer_details", "error"), - ("postgres_example_test_dataset:employee", "in_processing"), - ("postgres_example_test_dataset:employee", "complete"), - ("postgres_example_test_dataset:service_request", "in_processing"), - 
("postgres_example_test_dataset:service_request", "complete"), - ("mongo_test:customer_feedback", "skipped"), - ("mongo_test:internal_customer_profile", "skipped"), - ("mongo_test:rewards", "skipped"), - ("postgres_example_test_dataset:report", "in_processing"), - ("postgres_example_test_dataset:report", "complete"), - ("postgres_example_test_dataset:visit", "in_processing"), - ("postgres_example_test_dataset:visit", "complete"), - ("postgres_example_test_dataset:address", "in_processing"), - ("postgres_example_test_dataset:address", "complete"), - ("mongo_test:customer_details", "skipped"), - ("mongo_test:flights", "skipped"), - ("mongo_test:employee", "skipped"), - ("mongo_test:aircraft", "skipped"), - ("mongo_test:conversations", "skipped"), - ("mongo_test:payment_card", "skipped"), - ("postgres_example_test_dataset:login", "in_processing"), - ("postgres_example_test_dataset:login", "complete"), - ], "Rerun skips disabled collections" - - assert all( - [dataset.startswith("postgres_example") for dataset in results] - ), "No mongo results" + if dsr_version == "use_dsr_2_0": + execution_logs = get_sorted_execution_logs(db, privacy_request) + assert execution_logs == [ + ("postgres_example_test_dataset:customer", "in_processing"), + ("postgres_example_test_dataset:customer", "complete"), + ("postgres_example_test_dataset:payment_card", "in_processing"), + ("postgres_example_test_dataset:payment_card", "complete"), + ("postgres_example_test_dataset:orders", "in_processing"), + ("postgres_example_test_dataset:orders", "complete"), + ("postgres_example_test_dataset:order_item", "in_processing"), + ("postgres_example_test_dataset:order_item", "complete"), + ("postgres_example_test_dataset:product", "in_processing"), + ("postgres_example_test_dataset:product", "complete"), + ("mongo_test:customer_details", "in_processing"), + ("mongo_test:customer_details", "error"), + ("postgres_example_test_dataset:employee", "in_processing"), + ("postgres_example_test_dataset:employee", "complete"), + ("postgres_example_test_dataset:service_request", "in_processing"), + ("postgres_example_test_dataset:service_request", "complete"), + ("mongo_test:customer_feedback", "skipped"), + ("mongo_test:internal_customer_profile", "skipped"), + ("mongo_test:rewards", "skipped"), + ("postgres_example_test_dataset:report", "in_processing"), + ("postgres_example_test_dataset:report", "complete"), + ("postgres_example_test_dataset:visit", "in_processing"), + ("postgres_example_test_dataset:visit", "complete"), + ("postgres_example_test_dataset:address", "in_processing"), + ("postgres_example_test_dataset:address", "complete"), + ("mongo_test:customer_details", "skipped"), + ("mongo_test:flights", "skipped"), + ("mongo_test:employee", "skipped"), + ("mongo_test:aircraft", "skipped"), + ("mongo_test:conversations", "skipped"), + ("mongo_test:payment_card", "skipped"), + ("postgres_example_test_dataset:login", "in_processing"), + ("postgres_example_test_dataset:login", "complete"), + ], "Rerun skips disabled collections" + + assert all( + [dataset.startswith("postgres_example") for dataset in results] + ), "No mongo results" + + else: + # Rerun also skips disabled collections for DSR 3.0 + customer_detail_logs = db.query(ExecutionLog).filter_by( + privacy_request_id=privacy_request.id, + dataset_name="mongo_test", + collection_name="customer_details", + ) + assert ["in_processing", "error", "skipped"] == [ + log.status.value + for log in customer_detail_logs.order_by(ExecutionLog.created_at) + ] + customer_feedback_logs = 
db.query(ExecutionLog).filter_by( + privacy_request_id=privacy_request.id, + dataset_name="mongo_test", + collection_name="customer_feedback", + ) + assert ["in_processing", "error", "skipped"] == [ + log.status.value + for log in customer_feedback_logs.order_by(ExecutionLog.created_at) + ] @pytest.mark.usefixtures( "postgres_integration_db", "postgres_example_test_dataset_config_read_access" ) + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_disable_connection_config_on_completed_request( self, db, policy, read_connection_config, run_privacy_request_task, + dsr_version, + request, ) -> None: """Disabling the connection config on a completed request leaves execution logs untouched""" - customer_email = "customer-1@example.com" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + customer_email = "customer-4@example.com" data = { "requested_at": "2021-08-30T16:09:37.359Z", "policy_key": policy.key, @@ -597,6 +749,7 @@ def _build_postgres_dataset_graph_with_skipped_collection( skipped_collection_name: Optional[str], ): """test helper""" + dataset_postgres = Dataset(**example_datasets[0]) if skipped_collection_name: skipped_collection = next( @@ -612,29 +765,33 @@ def _build_postgres_dataset_graph_with_skipped_collection( return dataset_graph @pytest.mark.asyncio + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_no_collections_marked_as_skipped( self, db, policy, example_datasets, + dsr_version, + request, integration_postgres_config, + privacy_request, ) -> None: """Sanity check - nothing marked as skipped. All collections expected in results.""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 postgres_graph = self._build_postgres_dataset_graph_with_skipped_collection( example_datasets, integration_postgres_config, skipped_collection_name=None ) - privacy_request = PrivacyRequest( - id=f"test_postgres_access_request_task_{uuid.uuid4()}" - ) - - results = await graph_task.run_access_request( + results = access_runner_tester( privacy_request, policy, postgres_graph, [integration_postgres_config], - {"email": "customer-1@example.com"}, + {"email": "customer-4@example.com"}, db, ) @@ -642,14 +799,22 @@ async def test_no_collections_marked_as_skipped( assert "login" not in results @pytest.mark.asyncio + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_collection_marked_as_skipped_with_nothing_downstream( self, db, policy, example_datasets, + privacy_request, + dsr_version, + request, integration_postgres_config, ) -> None: """Mark the login collection as skipped. 
This collection has no downstream dependencies, so skipping is fine!""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 postgres_graph = self._build_postgres_dataset_graph_with_skipped_collection( example_datasets, @@ -657,16 +822,12 @@ async def test_collection_marked_as_skipped_with_nothing_downstream( skipped_collection_name="login", ) - privacy_request = PrivacyRequest( - id=f"test_postgres_access_request_task_{uuid.uuid4()}" - ) - - results = await graph_task.run_access_request( + results = access_runner_tester( privacy_request, policy, postgres_graph, [integration_postgres_config], - {"email": "customer-1@example.com"}, + {"email": "customer-4@example.com"}, db, ) @@ -674,19 +835,23 @@ async def test_collection_marked_as_skipped_with_nothing_downstream( assert "login" not in results @pytest.mark.asyncio + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_collection_marked_as_skipped_with_dependencies( self, db, policy, + privacy_request, example_datasets, + dsr_version, + request, integration_postgres_config, ) -> None: """Mark the address collection as skipped. Many collections are marked as relying on this collection so this fails early when building the DatasetGraph""" - - privacy_request = PrivacyRequest( - id=f"test_postgres_access_request_task_{uuid.uuid4()}" - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 with pytest.raises(common_exceptions.ValidationError): postgres_graph = self._build_postgres_dataset_graph_with_skipped_collection( @@ -695,18 +860,24 @@ example_datasets, integration_postgres_config, skipped_collection_name="address", ) - await graph_task.run_access_request( - privacy_request, - policy, - postgres_graph, - [integration_postgres_config], - {"email": "customer-1@example.com"}, - db, + access_runner_tester( + privacy_request, + policy, + postgres_graph, + [integration_postgres_config], + {"email": "customer-4@example.com"}, + db, ) @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_restart_graph_from_failure( db, policy, @@ -714,92 +885,212 @@ integration_postgres_config, integration_mongodb_config, mongo_postgres_dataset_graph, + privacy_request, + dsr_version, + request, ) -> None: """Run a failed privacy request and restart from failure""" - - privacy_request = PrivacyRequest( - id=f"test_postgres_access_request_task_{uuid.uuid4()}" - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 # Temporarily remove the secrets from the mongo connection to prevent execution from occurring saved_secrets = integration_mongodb_config.secrets integration_mongodb_config.secrets = {} integration_mongodb_config.save(db) - # Attempt to run the graph; execution will stop when we reach one of the mongo nodes - with pytest.raises(Exception) as exc: - await graph_task.run_access_request( + # Attempt to run the graph; execution will stop when we reach one of the mongo nodes for DSR 2.0 + if dsr_version == "use_dsr_2_0": + with pytest.raises(Exception) as exc: + access_runner_tester( + privacy_request, + policy, + mongo_postgres_dataset_graph, + [integration_postgres_config, integration_mongodb_config], + {"email": "customer-4@example.com"}, + db, + ) + assert exc.value.__class__ == ValidationError + assert ( + "MongoDBSchema must be supplied all of: ['host', 'username', 'password', 
'defaultauthdb']" + in str(exc.value) + ) + + execution_logs = get_sorted_execution_logs(db, privacy_request) + # Assert execution logs failed at mongo node + assert execution_logs == [ + ("postgres_example_test_dataset:customer", "in_processing"), + ("postgres_example_test_dataset:customer", "complete"), + ("postgres_example_test_dataset:payment_card", "in_processing"), + ("postgres_example_test_dataset:payment_card", "complete"), + ("postgres_example_test_dataset:orders", "in_processing"), + ("postgres_example_test_dataset:orders", "complete"), + ("postgres_example_test_dataset:order_item", "in_processing"), + ("postgres_example_test_dataset:order_item", "complete"), + ("postgres_example_test_dataset:product", "in_processing"), + ("postgres_example_test_dataset:product", "complete"), + ("mongo_test:customer_details", "in_processing"), + ("mongo_test:customer_details", "error"), + ] + else: + access_runner_tester( privacy_request, policy, mongo_postgres_dataset_graph, [integration_postgres_config, integration_mongodb_config], - {"email": "customer-1@example.com"}, + {"email": "customer-4@example.com"}, db, ) - assert exc.value.__class__ == ValidationError - assert ( - "MongoDBSchema must be supplied all of: ['host', 'username', 'password', 'defaultauthdb']" - in str(exc.value) - ) + # Multiple mongo level nodes attempted to run in DSR 3.0 before failing and blocking downstream nodes + customer_detail_logs = db.query(ExecutionLog).filter_by( + privacy_request_id=privacy_request.id, + dataset_name="mongo_test", + collection_name="customer_details", + ) + assert ["in_processing", "error"] == [ + log.status.value + for log in customer_detail_logs.order_by(ExecutionLog.created_at) + ] + customer_feedback_logs = db.query(ExecutionLog).filter_by( + privacy_request_id=privacy_request.id, + dataset_name="mongo_test", + collection_name="customer_feedback", + ) + assert ["in_processing", "error"] == [ + log.status.value + for log in customer_feedback_logs.order_by(ExecutionLog.created_at) + ] - execution_logs = get_sorted_execution_logs(db, privacy_request) - # Assert execution logs failed at mongo node - assert execution_logs == [ - ("postgres_example_test_dataset:customer", "in_processing"), - ("postgres_example_test_dataset:customer", "complete"), - ("postgres_example_test_dataset:payment_card", "in_processing"), - ("postgres_example_test_dataset:payment_card", "complete"), - ("postgres_example_test_dataset:orders", "in_processing"), - ("postgres_example_test_dataset:orders", "complete"), - ("postgres_example_test_dataset:order_item", "in_processing"), - ("postgres_example_test_dataset:order_item", "complete"), - ("postgres_example_test_dataset:product", "in_processing"), - ("postgres_example_test_dataset:product", "complete"), - ("mongo_test:customer_details", "in_processing"), - ("mongo_test:customer_details", "error"), - ] assert privacy_request.get_failed_checkpoint_details() == CheckpointActionRequired( step=CurrentStep.access, - collection=CollectionAddress("mongo_test", "customer_details"), ) # Reset secrets integration_mongodb_config.secrets = saved_secrets integration_mongodb_config.save(db) - # Rerun access request using cached results - with mock.patch("fides.api.task.graph_task.fideslog_graph_rerun") as mock_log_event: - await graph_task.run_access_request( + access_runner_tester( + privacy_request, + policy, + mongo_postgres_dataset_graph, + [integration_postgres_config, integration_mongodb_config], + {"email": "customer-4@example.com"}, + db, + ) + + assert ( + 
db.query(ExecutionLog)
+        .filter_by(
+            privacy_request_id=privacy_request.id,
+            dataset_name="postgres_example_test_dataset",
+            collection_name="customer",
+        )
+        .count()
+        == 2
+    ), "Postgres customer collection does not re-run"
+
+    assert db.query(ExecutionLog).filter_by(
+        privacy_request_id=privacy_request.id,
+        dataset_name="mongo_test",
+        collection_name="customer_details",
+    ).count()
+
+    customer_detail_logs = [
+        (
+            CollectionAddress(log.dataset_name, log.collection_name).value,
+            log.status.value,
+        )
+        for log in db.query(ExecutionLog)
+        .filter_by(
+            privacy_request_id=privacy_request.id,
+            dataset_name="mongo_test",
+            collection_name="customer_details",
+        )
+        .order_by("created_at")
+    ]
+
+    assert customer_detail_logs == [
+        ("mongo_test:customer_details", "in_processing"),
+        ("mongo_test:customer_details", "error"),
+        ("mongo_test:customer_details", "in_processing"),
+        ("mongo_test:customer_details", "complete"),
+    ], "Mongo customer_details node reruns"
+
+
+@pytest.mark.integration
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "dsr_version",
+    ["use_dsr_3_0", "use_dsr_2_0"],
+)
+async def test_restart_graph_from_failure_on_different_scheduler(
+    db,
+    policy,
+    example_datasets,
+    integration_postgres_config,
+    integration_mongodb_config,
+    mongo_postgres_dataset_graph,
+    privacy_request,
+    dsr_version,
+    request,
+) -> None:
+    """Run a failed privacy request, then switch the DSR scheduler version and restart it from failure"""
+    request.getfixturevalue(dsr_version)  # REQUIRED to test both DSR 3.0 and 2.0
+
+    # Temporarily remove the secrets from the mongo connection to prevent execution from occurring
+    saved_secrets = integration_mongodb_config.secrets
+    integration_mongodb_config.secrets = {}
+    integration_mongodb_config.save(db)
+
+    # Attempt to run the graph; execution will stop when we reach one of the mongo nodes for DSR 2.0
+    if dsr_version == "use_dsr_2_0":
+        with pytest.raises(Exception) as exc:
+            access_runner_tester(
+                privacy_request,
+                policy,
+                mongo_postgres_dataset_graph,
+                [integration_postgres_config, integration_mongodb_config],
+                {"email": "customer-4@example.com"},
+                db,
+            )
+    else:
+        access_runner_tester(
             privacy_request,
             policy,
             mongo_postgres_dataset_graph,
             [integration_postgres_config, integration_mongodb_config],
-            {"email": "customer-1@example.com"},
+            {"email": "customer-4@example.com"},
             db,
         )
 
-    # Assert analytics event created - before and after graph on rerun did not change
-    analytics_event = mock_log_event.call_args.args[0]
-    assert analytics_event.docker is True
-    assert analytics_event.event == "rerun_access_graph"
-    assert analytics_event.event_created_at is not None
-    assert analytics_event.extra_data == {
-        "prev_collection_count": 20,
-        "curr_collection_count": 20,
-        "added_collection_count": 0,
-        "removed_collection_count": 0,
-        "added_edge_count": 0,
-        "removed_edge_count": 0,
-        "already_processed_access_collection_count": 5,
-        "already_processed_erasure_collection_count": 0,
-        "skipped_added_edge_count": 0,
-        "privacy_request": privacy_request.id,
-    }
+    assert privacy_request.get_failed_checkpoint_details() == CheckpointActionRequired(
+        step=CurrentStep.access,
+    )
+
+    # Test switching the version from when the Privacy Request was first run
+    if dsr_version == "use_dsr_3_0":
+        original_version = 3.0
+        CONFIG.execution.use_dsr_3_0 = False
+    else:
+        original_version = 2.0
+        CONFIG.execution.use_dsr_3_0 = True
 
-    assert analytics_event.error is None
-    assert analytics_event.status_code is None
-    assert analytics_event.endpoint is None
-    assert analytics_event.local_host is None
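A note on the mechanics used throughout this diff: the `@pytest.mark.parametrize("dsr_version", [...])` IDs are fixture names, and `request.getfixturevalue(dsr_version)` pulls the named fixture in at runtime, so each test body executes once per scheduler. The fixtures themselves live in tests/conftest.py (alongside `access_runner_tester` and `erasure_runner_tester`, which are imported from there elsewhere in this diff) and are not shown here; below is a minimal sketch of how they plausibly work, assuming they do nothing more than flip the same `CONFIG.execution.use_dsr_3_0` flag that the version-switch block above toggles by hand, and restore it afterwards.

import pytest

from fides.config import get_config

CONFIG = get_config()


@pytest.fixture
def use_dsr_2_0():
    # Hypothetical reconstruction: force the legacy DSR 2.0 scheduler for this test.
    original = CONFIG.execution.use_dsr_3_0
    CONFIG.execution.use_dsr_3_0 = False
    yield
    CONFIG.execution.use_dsr_3_0 = original


@pytest.fixture
def use_dsr_3_0():
    # Hypothetical reconstruction: force the DSR 3.0 request-task scheduler.
    original = CONFIG.execution.use_dsr_3_0
    CONFIG.execution.use_dsr_3_0 = True
    yield
    CONFIG.execution.use_dsr_3_0 = original
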
+ # Reset secrets + integration_mongodb_config.secrets = saved_secrets + integration_mongodb_config.save(db) + + access_runner_tester( + privacy_request, + policy, + mongo_postgres_dataset_graph, + [integration_postgres_config, integration_mongodb_config], + {"email": "customer-4@example.com"}, + db, + ) + + db.refresh(privacy_request) + if original_version == 2.0: + assert not privacy_request.access_tasks.count() + else: + assert privacy_request.access_tasks.count() assert ( db.query(ExecutionLog) @@ -842,99 +1133,117 @@ async def test_restart_graph_from_failure( @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_restart_graph_from_failure_during_erasure( db, - policy, - example_datasets, + erasure_policy, integration_postgres_config, integration_mongodb_config, mongo_postgres_dataset_graph, + dsr_version, + request, + privacy_request_with_erasure_policy, ) -> None: """Run a failed privacy request and restart from failure during the erasure portion. An erasure request first runs an access and then an erasure request. If the erasure portion fails, and we reprocess, we don't re-run the access portion currently. """ - - privacy_request = PrivacyRequest( - id=f"test_postgres_access_request_task_{uuid.uuid4()}" - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 # Run access portion like normal - await graph_task.run_access_request( - privacy_request, - policy, + access_runner_tester( + privacy_request_with_erasure_policy, + erasure_policy, mongo_postgres_dataset_graph, [integration_postgres_config, integration_mongodb_config], - {"email": "customer-1@example.com"}, + {"email": "customer-4@example.com"}, db, ) - - # Temporarily remove the secrets from the postgres connection to prevent execution from occurring - saved_secrets = integration_postgres_config.secrets - integration_postgres_config.secrets = {} - integration_postgres_config.save(db) - - # Attempt to run the erasure graph; execution will stop when we reach one of the mongo nodes - with pytest.raises(Exception) as exc: - await graph_task.run_erasure( - privacy_request, - policy, - mongo_postgres_dataset_graph, - [integration_postgres_config, integration_mongodb_config], - {"email": "customer-1@example.com"}, - get_cached_data_for_erasures(privacy_request.id), - db, - ) - assert exc.value.__class__ == ValidationError - assert ( - exc.value.errors()[0]["msg"] - == "PostgreSQLSchema must be supplied a 'url' or all of: ['host']." 
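Both restart tests above force their failures the same way: stash the connector's secrets, blank them so every node on that connection errors, run, then write the stash back before reprocessing (the `# Reset secrets` step). The erasure variant below repeats the pattern across several postgres ConnectionConfigs at once. A context manager is a natural way to make the stash-and-restore pair exception-safe; this is a sketch of the idea, not a helper that exists in the codebase:

from contextlib import contextmanager


@contextmanager
def secrets_removed(db, connection_config):
    # Blank the secrets so execution fails when it reaches this connection's nodes.
    saved = connection_config.secrets
    connection_config.secrets = {}
    connection_config.save(db)
    try:
        yield
    finally:
        # Restore the real secrets so the privacy request can be reprocessed.
        connection_config.secrets = saved
        connection_config.save(db)
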
+ assert [("access", "in_processing"), ("access", "complete")] == [ + (c.action_type.value, c.status.value) + for c in db.query(ExecutionLog) + .filter_by( + privacy_request_id=privacy_request_with_erasure_policy.id, + collection_name="address", ) + .order_by(ExecutionLog.created_at) + .all() + ] - # Reset secrets - integration_postgres_config.secrets = saved_secrets - integration_postgres_config.save(db) - - # Rerun erasure portion of request using cached results - with mock.patch("fides.api.task.graph_task.fideslog_graph_rerun") as mock_log_event: - await graph_task.run_erasure( - privacy_request, - policy, + saved_secrets = {} + for cc in db.query(ConnectionConfig).filter( + ConnectionConfig.connection_type == ConnectionType.postgres + ): + saved_secrets[cc.key] = cc.secrets.copy() + cc.secrets = None + cc.created_at = datetime.now() + cc.save(db) + db.commit() + + # Attempt to run the erasure graph; execution will stop when we reach one of the postgres nodes + if dsr_version == "use_dsr_2_0": + with pytest.raises(Exception) as exc: + erasure_runner_tester( + privacy_request_with_erasure_policy, + erasure_policy, + mongo_postgres_dataset_graph, + [integration_postgres_config, integration_mongodb_config], + {"email": "customer-4@example.com"}, + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), + db, + ) + assert exc.value.__class__ == ValidationError + assert ( + exc.value.errors()[0]["msg"] + == "PostgreSQLSchema must be supplied a 'url' or all of: ['host']." + ) + else: + # DSR 3.0 does not fail the entire privacy request as a whole by raising an exception. + # An AP Scheduler will come along and mark it as failed later + erasure_runner_tester( + privacy_request_with_erasure_policy, + erasure_policy, mongo_postgres_dataset_graph, [integration_postgres_config, integration_mongodb_config], - {"email": "customer-1@example.com"}, - get_cached_data_for_erasures(privacy_request.id), + {"email": "customer-4@example.com"}, + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) + assert ["in_processing", "complete", "in_processing", "error"] == [ + c.status.value + for c in db.query(ExecutionLog) + .filter_by( + privacy_request_id=privacy_request_with_erasure_policy.id, + collection_name="address", + ) + .order_by(ExecutionLog.created_at) + .all() + ] - # Assert analytics event created - before and after graph on rerun did not change - analytics_event = mock_log_event.call_args.args[0] - assert analytics_event.docker is True - assert analytics_event.event == "rerun_erasure_graph" - assert analytics_event.event_created_at is not None - assert analytics_event.extra_data == { - "prev_collection_count": 20, - "curr_collection_count": 20, - "added_collection_count": 0, - "removed_collection_count": 0, - "added_edge_count": 0, - "removed_edge_count": 0, - "already_processed_access_collection_count": 20, - "already_processed_erasure_collection_count": 9, - "skipped_added_edge_count": 0, - "privacy_request": privacy_request.id, - } + for config in db.query(ConnectionConfig).filter( + ConnectionConfig.connection_type == ConnectionType.postgres + ): + config.secrets = saved_secrets[config.key] + config.save(db) - assert analytics_event.error is None - assert analytics_event.status_code is None - assert analytics_event.endpoint is None - assert analytics_event.local_host is None + erasure_runner_tester( + privacy_request_with_erasure_policy, + erasure_policy, + mongo_postgres_dataset_graph, + [integration_postgres_config, integration_mongodb_config], + {"email": 
"customer-4@example.com"}, + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), + db, + ) assert ( db.query(ExecutionLog) .filter_by( - privacy_request_id=privacy_request.id, + privacy_request_id=privacy_request_with_erasure_policy.id, dataset_name="mongo_test", collection_name="customer_details", ) @@ -949,7 +1258,10 @@ async def test_restart_graph_from_failure_during_erasure( log.status.value, ) for log in db.query(ExecutionLog) - .filter_by(privacy_request_id=privacy_request.id, collection_name="address") + .filter_by( + privacy_request_id=privacy_request_with_erasure_policy.id, + collection_name="address", + ) .order_by("created_at") ] diff --git a/tests/ops/integration_tests/test_integration_attentive.py b/tests/ops/integration_tests/test_integration_attentive.py index 3520cfcacb..b4152946d3 100644 --- a/tests/ops/integration_tests/test_integration_attentive.py +++ b/tests/ops/integration_tests/test_integration_attentive.py @@ -25,10 +25,16 @@ "fides.api.service.privacy_request.email_batch_service.requeue_privacy_requests_after_email_send", ) @mock.patch("fides.api.service.messaging.message_dispatch_service._mailgun_dispatcher") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_erasure_email( mock_mailgun_dispatcher: Mock, mock_requeue_privacy_requests: Mock, db, + dsr_version, + request, erasure_policy, attentive_email_connection_config, run_privacy_request_task, @@ -40,6 +46,7 @@ async def test_erasure_email( Verify the privacy request is set to "awaiting email send" and that one email is sent when the send_email_batch job is executed manually """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 pr = get_privacy_request_results( db, @@ -91,10 +98,16 @@ async def test_erasure_email( "fides.api.service.privacy_request.email_batch_service.requeue_privacy_requests_after_email_send", ) @mock.patch("fides.api.service.messaging.message_dispatch_service._mailgun_dispatcher") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_erasure_email_no_messaging_config( mock_mailgun_dispatcher: Mock, mock_requeue_privacy_requests: Mock, db, + dsr_version, + request, erasure_policy, attentive_email_connection_config, run_privacy_request_task, @@ -105,6 +118,7 @@ async def test_erasure_email_no_messaging_config( Verify the privacy request is set to "awaiting email send" and that the email fails to send because of the missing messaging config. """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 pr = get_privacy_request_results( db, @@ -134,8 +148,14 @@ async def test_erasure_email_no_messaging_config( @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_erasure_email_no_write_permissions( db, + dsr_version, + request, erasure_policy, attentive_email_connection_config, run_privacy_request_task, @@ -145,6 +165,7 @@ async def test_erasure_email_no_write_permissions( Run an erasure privacy request with only an email (Attentive) connector. Verify we don't send an email for a connector with read-only access. 
""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 attentive_email_connection_config.update( db=db, @@ -172,8 +193,14 @@ async def test_erasure_email_no_write_permissions( @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_erasure_email_no_updates_needed( db, + dsr_version, + request, policy, attentive_email_connection_config, run_privacy_request_task, @@ -184,6 +211,7 @@ async def test_erasure_email_no_updates_needed( Verify the privacy request is set to "complete" because this is an access request and no erasures are needed. """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 pr = get_privacy_request_results( db, @@ -210,10 +238,16 @@ async def test_erasure_email_no_updates_needed( "fides.api.service.privacy_request.email_batch_service.requeue_privacy_requests_after_email_send", ) @mock.patch("fides.api.service.messaging.message_dispatch_service._mailgun_dispatcher") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_erasure_email_disabled_connector( mock_mailgun_dispatcher: Mock, mock_requeue_privacy_requests: Mock, db, + dsr_version, + request, erasure_policy, attentive_email_connection_config, run_privacy_request_task, @@ -225,6 +259,7 @@ async def test_erasure_email_disabled_connector( Verify the privacy request is set to "awaiting email send" and that one email is sent when the send_email_batch job is executed manually """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 attentive_email_connection_config.update( db=db, @@ -255,10 +290,16 @@ async def test_erasure_email_disabled_connector( "fides.api.service.privacy_request.email_batch_service.requeue_privacy_requests_after_email_send", ) @mock.patch("fides.api.service.messaging.message_dispatch_service._mailgun_dispatcher") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_erasure_email_unsupported_identity( mock_mailgun_dispatcher: Mock, mock_requeue_privacy_requests: Mock, db, + dsr_version, + request, erasure_policy, attentive_email_connection_config, run_privacy_request_task, @@ -269,6 +310,7 @@ async def test_erasure_email_unsupported_identity( Run an erasure privacy request with only an email (Attentive) connector. Verify the privacy request is set to "complete" because the provided identities are not supported. 
""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 pr = get_privacy_request_results( db, diff --git a/tests/ops/integration_tests/test_integration_custom_privacy_request_fields.py b/tests/ops/integration_tests/test_integration_custom_privacy_request_fields.py index acb306a07a..71f3a1ea46 100644 --- a/tests/ops/integration_tests/test_integration_custom_privacy_request_fields.py +++ b/tests/ops/integration_tests/test_integration_custom_privacy_request_fields.py @@ -82,10 +82,22 @@ def custom_privacy_request_fields_dataset( "allow_custom_privacy_request_field_collection_enabled", "allow_custom_privacy_request_fields_in_request_execution_enabled", ) + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") def test_custom_privacy_request_fields_access( - self, mock_send, db: Session, policy: Policy, run_privacy_request_task + self, + mock_send, + dsr_version, + request, + db: Session, + policy: Policy, + run_privacy_request_task, ) -> None: + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + data = { "requested_at": "2021-08-30T16:09:37.359Z", "policy_key": policy.key, @@ -133,10 +145,22 @@ def test_custom_privacy_request_fields_access( "allow_custom_privacy_request_field_collection_enabled", "allow_custom_privacy_request_fields_in_request_execution_enabled", ) + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") def test_custom_privacy_request_fields_erasure( - self, mock_send, db: Session, erasure_policy: Policy, run_privacy_request_task + self, + mock_send, + dsr_version, + request, + db: Session, + erasure_policy: Policy, + run_privacy_request_task, ) -> None: + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + data = { "requested_at": "2021-08-30T16:09:37.359Z", "policy_key": erasure_policy.key, diff --git a/tests/ops/integration_tests/test_integration_erasure_order.py b/tests/ops/integration_tests/test_integration_erasure_order.py index bc1f30e623..2e11770b9a 100644 --- a/tests/ops/integration_tests/test_integration_erasure_order.py +++ b/tests/ops/integration_tests/test_integration_erasure_order.py @@ -1,4 +1,3 @@ -import random from typing import Any, Dict, List from unittest import mock @@ -17,10 +16,11 @@ SaaSRequestType, register, ) -from fides.api.task import graph_task +from fides.api.task.graph_runners import access_runner, erasure_runner from fides.api.task.graph_task import get_cached_data_for_erasures from fides.api.util.collection_util import Row from fides.config import get_config +from tests.conftest import access_runner_tester, erasure_runner_tester from tests.ops.graph.graph_test_util import assert_rows_match CONFIG = get_config() @@ -66,16 +66,20 @@ def delete_no_op( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.usefixtures("use_dsr_2_0") async def test_saas_erasure_order_request_task( db, policy, + privacy_request, erasure_policy_complete_mask, saas_erasure_order_connection_config, saas_erasure_order_dataset_config, ) -> None: - privacy_request = PrivacyRequest( - id=f"test_saas_erasure_order_request_task_{random.randint(0, 1000)}" - ) + """This test uses DSR 2.0 specifically. 
Equivalent concept for DSR 3.0 tested + in test_create_request_tasks.py""" + privacy_request.policy_id = erasure_policy_complete_mask.id + privacy_request.save(db) + identity_attribute = "email" identity_value = "test@ethyca.com" identity_kwargs = {identity_attribute: identity_value} @@ -86,9 +90,9 @@ async def test_saas_erasure_order_request_task( merged_graph = saas_erasure_order_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner( privacy_request, - policy, + erasure_policy_complete_mask, graph, [saas_erasure_order_connection_config], {"email": "test@ethyca.com"}, @@ -113,7 +117,7 @@ async def test_saas_erasure_order_request_task( temp_masking = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False - x = await graph_task.run_erasure( + x = erasure_runner( privacy_request, erasure_policy_complete_mask, graph, @@ -156,17 +160,20 @@ async def test_saas_erasure_order_request_task( @pytest.mark.integration_saas @pytest.mark.asyncio +@pytest.mark.usefixtures("use_dsr_2_0") async def test_saas_erasure_order_request_task_with_cycle( db, - policy, + privacy_request, erasure_policy_complete_mask, saas_erasure_order_config, saas_erasure_order_connection_config, saas_erasure_order_dataset_config, ) -> None: - privacy_request = PrivacyRequest( - id=f"test_saas_erasure_order_request_task_with_cycle_{random.randint(0, 1000)}" - ) + """This test uses DSR 2.0 specifically. Equivalent concept for DSR 3.0 tested + in test_create_request_tasks.py""" + privacy_request.policy_id = erasure_policy_complete_mask.id + privacy_request.save(db) + identity_attribute = "email" identity_value = "test@ethyca.com" identity_kwargs = {identity_attribute: identity_value} @@ -185,9 +192,9 @@ async def test_saas_erasure_order_request_task_with_cycle( merged_graph = saas_erasure_order_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner( privacy_request, - policy, + erasure_policy_complete_mask, graph, [saas_erasure_order_connection_config], {"email": "test@ethyca.com"}, @@ -213,7 +220,7 @@ async def test_saas_erasure_order_request_task_with_cycle( CONFIG.execution.masking_strict = False with pytest.raises(TraversalError) as exc: - await graph_task.run_erasure( + erasure_runner( privacy_request, erasure_policy_complete_mask, graph, @@ -234,6 +241,10 @@ async def test_saas_erasure_order_request_task_with_cycle( @pytest.mark.integration_saas @pytest.mark.asyncio @mock.patch("fides.api.service.connectors.saas_connector.SaaSConnector.mask_data") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_saas_erasure_order_request_task_resume_from_error( mock_mask_data, db, @@ -241,10 +252,16 @@ async def test_saas_erasure_order_request_task_resume_from_error( erasure_policy_complete_mask, saas_erasure_order_connection_config, saas_erasure_order_dataset_config, + privacy_request, + dsr_version, + request, ) -> None: - privacy_request = PrivacyRequest( - id=f"test_saas_erasure_order_request_task_resume_from_error_{random.randint(0, 1000)}" - ) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + # Policy needs to actually be set correctly on the privacy request for 3.0 testing + privacy_request.policy_id = erasure_policy_complete_mask.id + privacy_request.save(db) + identity_attribute = "email" identity_value = "test@ethyca.com" identity_kwargs = {identity_attribute: identity_value} @@ -255,9 +272,9 @@ 
async def test_saas_erasure_order_request_task_resume_from_error( merged_graph = saas_erasure_order_dataset_config.get_graph() graph = DatasetGraph(merged_graph) - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, - policy, + erasure_policy_complete_mask, # If we are doing an erasure request next, this needs to be accurate for DSR 3.0 graph, [saas_erasure_order_connection_config], {"email": "test@ethyca.com"}, @@ -284,15 +301,31 @@ async def test_saas_erasure_order_request_task_resume_from_error( # mock the mask_data function so we can force an exception on the "refunds_to_orders" # collection to simulate resuming from error - def side_effect(node, policy, privacy_request, rows, input_data): + def side_effect(node, policy, privacy_request, request_task, rows): if node.address.collection == "refunds_to_orders": raise Exception("Error executing refunds_to_orders task") + request_task.rows_masked = 1 + if request_task.id: + # DSR 3.0 needs to save this to the request task + session = Session.object_session(request_task) + request_task.save(session) return 1 mock_mask_data.side_effect = side_effect - with pytest.raises(Exception): - await graph_task.run_erasure( + if dsr_version == "use_dsr_2_0": + with pytest.raises(Exception): + erasure_runner_tester( + privacy_request, + erasure_policy_complete_mask, + graph, + [saas_erasure_order_connection_config], + identity_kwargs, + get_cached_data_for_erasures(privacy_request.id), + db, + ) + else: + erasure_runner_tester( privacy_request, erasure_policy_complete_mask, graph, @@ -303,11 +336,16 @@ def side_effect(node, policy, privacy_request, rows, input_data): ) # "fix" the refunds_to_orders collection and resume the erasure - mock_mask_data.side_effect = ( - lambda node, policy, privacy_request, rows, input_data: 1 - ) + def side_effect(node, policy, privacy_request, request_task, rows): + request_task.rows_masked = 1 + if request_task.id: + session = Session.object_session(request_task) + request_task.save(session) + return 1 + + mock_mask_data.side_effect = side_effect - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, erasure_policy_complete_mask, graph, @@ -326,24 +364,52 @@ def side_effect(node, policy, privacy_request, rows, input_data): f"{dataset_name}:refunds_to_orders": 1, } - assert [ - (log.collection_name, log.status.value) - for log in erasure_execution_logs(db, privacy_request) - ] == [ - ("products", "in_processing"), - ("products", "complete"), - ("orders_to_refunds", "in_processing"), - ("orders_to_refunds", "complete"), - ("refunds_to_orders", "in_processing"), - ("refunds_to_orders", "error"), - ("refunds_to_orders", "in_processing"), - ("refunds_to_orders", "complete"), - ("orders", "in_processing"), - ("orders", "complete"), - ("refunds", "in_processing"), - ("refunds", "complete"), - ("labels", "in_processing"), - ("labels", "complete"), - ], "Cached collections were not re-executed after resuming the privacy request from errored state" + if dsr_version == "use_dsr_2_0": + assert [ + (log.collection_name, log.status.value) + for log in erasure_execution_logs(db, privacy_request) + ] == [ + ("products", "in_processing"), + ("products", "complete"), + ("orders_to_refunds", "in_processing"), + ("orders_to_refunds", "complete"), + ("refunds_to_orders", "in_processing"), + ("refunds_to_orders", "error"), + ("refunds_to_orders", "in_processing"), + ("refunds_to_orders", "complete"), + ("orders", "in_processing"), + ("orders", "complete"), + ("refunds", 
"in_processing"), + ("refunds", "complete"), + ("labels", "in_processing"), + ("labels", "complete"), + ], "Cached collections were not re-executed after resuming the privacy request from errored state" + else: + ordered_logs = [ + (el.collection_name, el.status.value) + for el in db.query(ExecutionLog) + .filter( + ExecutionLog.privacy_request_id == privacy_request.id, + ExecutionLog.action_type == ActionType.erasure, + ) + .order_by(ExecutionLog.collection_name, ExecutionLog.created_at) + .all() + ] + assert ordered_logs == [ + ("labels", "in_processing"), + ("labels", "complete"), + ("orders", "in_processing"), + ("orders", "complete"), + ("orders_to_refunds", "in_processing"), + ("orders_to_refunds", "complete"), + ("products", "in_processing"), + ("products", "complete"), + ("refunds", "in_processing"), + ("refunds", "complete"), + ("refunds_to_orders", "in_processing"), + ("refunds_to_orders", "error"), + ("refunds_to_orders", "in_processing"), + ("refunds_to_orders", "complete"), + ] CONFIG.execution.masking_strict = temp_masking diff --git a/tests/ops/integration_tests/test_integration_generic_email.py b/tests/ops/integration_tests/test_integration_generic_email.py index dc89659595..2f51fa6a44 100644 --- a/tests/ops/integration_tests/test_integration_generic_email.py +++ b/tests/ops/integration_tests/test_integration_generic_email.py @@ -21,6 +21,10 @@ @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) @mock.patch( "fides.api.service.privacy_request.email_batch_service.requeue_privacy_requests_after_email_send", ) @@ -29,6 +33,8 @@ async def test_erasure_email( mock_mailgun_dispatcher: Mock, mock_requeue_privacy_requests: Mock, db, + dsr_version, + request, erasure_policy, generic_erasure_email_connection_config, run_privacy_request_task, @@ -40,6 +46,7 @@ async def test_erasure_email( Verify the privacy request is set to "awaiting email send" and that one email is sent when the send_email_batch job is executed manually """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 pr = get_privacy_request_results( db, @@ -91,10 +98,16 @@ async def test_erasure_email( "fides.api.service.privacy_request.email_batch_service.requeue_privacy_requests_after_email_send", ) @mock.patch("fides.api.service.messaging.message_dispatch_service._mailgun_dispatcher") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_erasure_email_no_messaging_config( mock_mailgun_dispatcher: Mock, mock_requeue_privacy_requests: Mock, db, + dsr_version, + request, erasure_policy, generic_erasure_email_connection_config, run_privacy_request_task, @@ -105,6 +118,7 @@ async def test_erasure_email_no_messaging_config( Verify the privacy request is set to "awaiting email send" and that the email fails to send because of the missing messaging config. """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 pr = get_privacy_request_results( db, @@ -134,8 +148,14 @@ async def test_erasure_email_no_messaging_config( @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_erasure_email_no_write_permissions( db, + dsr_version, + request, erasure_policy, generic_erasure_email_connection_config, run_privacy_request_task, @@ -145,6 +165,7 @@ async def test_erasure_email_no_write_permissions( Run an erasure privacy request with only a generic erasure email connector. 
Verify we don't send an email for a connector with read-only access. """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 generic_erasure_email_connection_config.update( db=db, @@ -172,8 +193,14 @@ async def test_erasure_email_no_write_permissions( @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_erasure_email_no_updates_needed( db, + dsr_version, + request, policy, generic_erasure_email_connection_config, run_privacy_request_task, @@ -184,6 +211,7 @@ async def test_erasure_email_no_updates_needed( Verify the privacy request is set to "complete" because this is an access request and no erasures are needed. """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 pr = get_privacy_request_results( db, @@ -210,10 +238,16 @@ async def test_erasure_email_no_updates_needed( "fides.api.service.privacy_request.email_batch_service.requeue_privacy_requests_after_email_send", ) @mock.patch("fides.api.service.messaging.message_dispatch_service._mailgun_dispatcher") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_erasure_email_disabled_connector( mock_mailgun_dispatcher: Mock, mock_requeue_privacy_requests: Mock, db, + dsr_version, + request, erasure_policy, generic_erasure_email_connection_config, run_privacy_request_task, @@ -225,6 +259,7 @@ async def test_erasure_email_disabled_connector( Verify the privacy request is set to "awaiting email send" and that one email is sent when the send_email_batch job is executed manually """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 generic_erasure_email_connection_config.update( db=db, @@ -255,10 +290,16 @@ async def test_erasure_email_disabled_connector( "fides.api.service.privacy_request.email_batch_service.requeue_privacy_requests_after_email_send", ) @mock.patch("fides.api.service.messaging.message_dispatch_service._mailgun_dispatcher") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_erasure_email_unsupported_identity( mock_mailgun_dispatcher: Mock, mock_requeue_privacy_requests: Mock, db, + dsr_version, + request, erasure_policy, generic_erasure_email_connection_config, run_privacy_request_task, @@ -269,6 +310,7 @@ async def test_erasure_email_unsupported_identity( Run an erasure privacy request with only a generic erasure email connector. Verify the privacy request is set to "complete" because the provided identities are not supported. 
""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 pr = get_privacy_request_results( db, diff --git a/tests/ops/integration_tests/test_manual_task.py b/tests/ops/integration_tests/test_manual_task.py deleted file mode 100644 index 6367476d55..0000000000 --- a/tests/ops/integration_tests/test_manual_task.py +++ /dev/null @@ -1,479 +0,0 @@ -import uuid - -import pytest - -from fides.api.common_exceptions import PrivacyRequestPaused -from fides.api.graph.config import CollectionAddress -from fides.api.models.policy import CurrentStep -from fides.api.models.privacy_request import ( - ExecutionLog, - ExecutionLogStatus, - PrivacyRequest, -) -from fides.api.task import graph_task -from fides.config import CONFIG - -from ..graph.graph_test_util import assert_rows_match -from ..task.traversal_data import postgres_and_manual_nodes - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@pytest.mark.usefixtures("postgres_integration_db") -@pytest.mark.asyncio -async def test_postgres_with_manual_input_access_request_task( - db, - policy, - integration_postgres_config, - integration_manual_config, -) -> None: - """Run a privacy request with two manual nodes""" - privacy_request = PrivacyRequest( - id=f"test_postgres_access_request_task_{uuid.uuid4()}" - ) - - # ATTEMPT 1 - storage unit node will throw an exception. Waiting on manual input. - with pytest.raises(PrivacyRequestPaused): - await graph_task.run_access_request( - privacy_request, - policy, - postgres_and_manual_nodes("postgres_example", "manual_example"), - [integration_postgres_config, integration_manual_config], - {"email": "customer-1@example.com"}, - db, - ) - - paused_details = privacy_request.get_paused_collection_details() - assert paused_details.collection == CollectionAddress( - "manual_example", "storage_unit" - ) - assert paused_details.step == CurrentStep.access - assert len(paused_details.action_needed) == 1 - - assert paused_details.action_needed[0].locators == { - "email": ["customer-1@example.com"] - } - - assert paused_details.action_needed[0].get == ["box_id", "email"] - assert paused_details.action_needed[0].update is None - - # Mock user retrieving storage unit data by adding manual data to cache - privacy_request.cache_manual_access_input( - CollectionAddress.from_string("manual_example:storage_unit"), - [{"box_id": 5, "email": "customer-1@example.com"}], - ) - - # Attempt 2 - Filing cabinet node will throw an exception. Waiting on manual input. 
- with pytest.raises(PrivacyRequestPaused): - await graph_task.run_access_request( - privacy_request, - policy, - postgres_and_manual_nodes("postgres_example", "manual_example"), - [integration_postgres_config, integration_manual_config], - {"email": "customer-1@example.com"}, - db, - ) - - paused_details = privacy_request.get_paused_collection_details() - assert paused_details.collection == CollectionAddress( - "manual_example", "filing_cabinet" - ) - assert paused_details.step == CurrentStep.access - assert len(paused_details.action_needed) == 1 - assert paused_details.action_needed[0].locators == {"customer_id": [1]} - - assert paused_details.action_needed[0].get == [ - "id", - "authorized_user", - "customer_id", - "payment_card_id", - ] - assert paused_details.action_needed[0].update is None - - # Add manual filing cabinet data from the user - privacy_request.cache_manual_access_input( - CollectionAddress.from_string("manual_example:filing_cabinet"), - [{"id": 1, "authorized_user": "Jane Doe", "payment_card_id": "pay_bbb-bbb"}], - ) - - # Attempt 3 - All manual data has been retrieved. - v = await graph_task.run_access_request( - privacy_request, - policy, - postgres_and_manual_nodes("postgres_example", "manual_example"), - [integration_postgres_config, integration_manual_config], - {"email": "customer-1@example.com"}, - db, - ) - # Manual filing cabinet data returned - assert_rows_match( - v["manual_example:filing_cabinet"], - min_size=1, - keys=["id", "authorized_user", "payment_card_id"], - ) - - # Manual storage unit data returned - assert_rows_match( - v["manual_example:storage_unit"], - min_size=1, - keys=["box_id", "email"], - ) - - # One customer row returned - assert_rows_match( - v["postgres_example:customer"], - min_size=1, - keys=["id", "name", "email", "address_id"], - ) - - # Two payment card rows returned, one from customer_id input, other retrieved from a separate manual input - assert_rows_match( - v["postgres_example:payment_card"], - min_size=2, - keys=["id", "name", "ccn", "customer_id", "billing_address_id"], - ) - - assert_rows_match( - v["postgres_example:orders"], - min_size=3, - keys=["id", "customer_id", "shipping_address_id", "payment_card_id"], - ) - - assert_rows_match( - v["postgres_example:address"], - min_size=2, - keys=["city", "id", "state", "street", "zip"], - ) - - # Paused details removed from cache - paused_details = privacy_request.get_paused_collection_details() - assert paused_details is None - - execution_logs = db.query(ExecutionLog).filter_by( - privacy_request_id=privacy_request.id - ) - - # Customer node run once. - customer_logs = execution_logs.filter_by(collection_name="customer").order_by( - "created_at" - ) - assert [log.status for log in customer_logs] == [ - ExecutionLogStatus.in_processing, - ExecutionLogStatus.complete, - ] - assert customer_logs.count() == 2 - - # Storage unit node run twice. 
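The status-history checks here, and their replacements elsewhere in this diff, all share one query shape: filter ExecutionLog by privacy request and collection, order by created_at, and compare the sequence of status values. A small helper in that shape, hypothetical but built only from columns these tests already use, keeps the assertions readable:

from fides.api.models.privacy_request import ExecutionLog


def collection_statuses(db, privacy_request_id, collection_name):
    # Ordered status history for one collection of one privacy request.
    logs = (
        db.query(ExecutionLog)
        .filter_by(
            privacy_request_id=privacy_request_id,
            collection_name=collection_name,
        )
        .order_by(ExecutionLog.created_at)
    )
    return [log.status.value for log in logs]


# e.g. the storage unit node above, which paused once and then completed:
# assert collection_statuses(db, privacy_request.id, "storage_unit") == [
#     "in_processing", "paused", "in_processing", "complete",
# ]
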
- storage_unit_logs = execution_logs.filter_by( - collection_name="storage_unit" - ).order_by("created_at") - assert storage_unit_logs.count() == 4 - assert [log.status for log in storage_unit_logs] == [ - ExecutionLogStatus.in_processing, - ExecutionLogStatus.paused, - ExecutionLogStatus.in_processing, - ExecutionLogStatus.complete, - ] - - # Order node run once - order_logs = execution_logs.filter_by(collection_name="orders").order_by( - "created_at" - ) - assert [log.status for log in order_logs] == [ - ExecutionLogStatus.in_processing, - ExecutionLogStatus.complete, - ] - assert order_logs.count() == 2 - - # Filing cabinet node run twice - filing_cabinet_logs = execution_logs.filter_by( - collection_name="filing_cabinet" - ).order_by("created_at") - assert filing_cabinet_logs.count() == 4 - assert [log.status for log in filing_cabinet_logs] == [ - ExecutionLogStatus.in_processing, - ExecutionLogStatus.paused, - ExecutionLogStatus.in_processing, - ExecutionLogStatus.complete, - ] - - # Payment card node run once - payment_logs = execution_logs.filter_by(collection_name="payment_card").order_by( - "created_at" - ) - assert [log.status for log in payment_logs] == [ - ExecutionLogStatus.in_processing, - ExecutionLogStatus.complete, - ] - - # Address logs run once - address_logs = execution_logs.filter_by(collection_name="address").order_by( - "created_at" - ) - assert [log.status for log in address_logs] == [ - ExecutionLogStatus.in_processing, - ExecutionLogStatus.complete, - ] - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@pytest.mark.usefixtures("postgres_integration_db") -@pytest.mark.asyncio -async def test_no_manual_input_found( - policy, - db, - integration_postgres_config, - integration_manual_config, -) -> None: - """Assert manual node can be restarted with an empty list. There isn't necessarily manual data found.""" - privacy_request = PrivacyRequest( - id=f"test_postgres_access_request_task_{uuid.uuid4()}" - ) - - # ATTEMPT 1 - storage unit node will throw an exception. Waiting on manual input. - with pytest.raises(PrivacyRequestPaused): - await graph_task.run_access_request( - privacy_request, - policy, - postgres_and_manual_nodes("postgres_example", "manual_example"), - [integration_postgres_config, integration_manual_config], - {"email": "customer-1@example.com"}, - db, - ) - - paused_details = privacy_request.get_paused_collection_details() - assert paused_details.collection == CollectionAddress( - "manual_example", "storage_unit" - ) - assert paused_details.step == CurrentStep.access - - # Mock user retrieving storage unit data by adding manual data to cache, - # In this case, no data was found in the storage unit, so we pass in an empty list. - privacy_request.cache_manual_access_input( - CollectionAddress.from_string("manual_example:storage_unit"), - [], - ) - - # Attempt 2 - Filing cabinet node will throw an exception. Waiting on manual input. 
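This deletion is the flip side of the new scheduler. DSR 2.0 tracked partial progress in the cache and signalled "waiting" by raising out of the run, while DSR 3.0 persists one RequestTask row per collection (the RequestTask import shows up in test_mongo_task.py further down), so progress survives a restart without any pause/resume protocol. The assertions added earlier in this diff state the difference directly; restated as a sketch:

# From the scheduler-switch test above: after an access run,
if dsr_version == "use_dsr_2_0":
    # DSR 2.0 schedules everything in memory and writes no task rows...
    assert not privacy_request.access_tasks.count()
else:
    # ...while DSR 3.0 persists a RequestTask per collection, so state
    # survives process restarts and can be inspected mid-flight.
    assert privacy_request.access_tasks.count()
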
- with pytest.raises(PrivacyRequestPaused): - await graph_task.run_access_request( - privacy_request, - policy, - postgres_and_manual_nodes("postgres_example", "manual_example"), - [integration_postgres_config, integration_manual_config], - {"email": "customer-1@example.com"}, - db, - ) - - paused_details = privacy_request.get_paused_collection_details() - assert paused_details.collection == CollectionAddress( - "manual_example", "filing_cabinet" - ) - assert paused_details.step == CurrentStep.access - - # No filing cabinet input found - privacy_request.cache_manual_access_input( - CollectionAddress.from_string("manual_example:filing_cabinet"), - [], - ) - - # Attempt 3 - All manual data has been retrieved/attempted to be retrieved - v = await graph_task.run_access_request( - privacy_request, - policy, - postgres_and_manual_nodes("postgres_example", "manual_example"), - [integration_postgres_config, integration_manual_config], - {"email": "customer-1@example.com"}, - db, - ) - - # No filing cabinet data or storage unit data found - assert v["manual_example:filing_cabinet"] == [] - assert v["manual_example:storage_unit"] == [] - - # One customer row returned - assert_rows_match( - v["postgres_example:customer"], - min_size=1, - keys=["id", "name", "email", "address_id"], - ) - - # One payment card row returned - assert_rows_match( - v["postgres_example:payment_card"], - min_size=1, - keys=["id", "name", "ccn", "customer_id", "billing_address_id"], - ) - - # Paused node removed from cache - paused_details = privacy_request.get_paused_collection_details() - assert paused_details is None - - -@pytest.mark.integration_postgres -@pytest.mark.integration -@pytest.mark.asyncio -async def test_collections_with_manual_erasure_confirmation( - db, - erasure_policy, - integration_postgres_config, - integration_manual_config, - privacy_request, -) -> None: - """Run an erasure privacy request with two manual nodes""" - privacy_request.policy = erasure_policy - rule = erasure_policy.rules[0] - target = rule.targets[0] - target.data_category = "user" - - cached_data_for_erasures = { - "postgres_example:payment_card": [ - { - "id": "pay_aaa-aaa", - "name": "Example Card 1", - "ccn": 123456789, - "customer_id": 1, - "billing_address_id": 1, - }, - { - "id": "pay_bbb-bbb", - "name": "Example Card 2", - "ccn": 987654321, - "customer_id": 2, - "billing_address_id": 1, - }, - ], - "postgres_example:address": [ - { - "id": 1, - "street": "Example Street", - "city": "Exampletown", - "state": "NY", - "zip": "12345", - }, - { - "id": 2, - "street": "Example Lane", - "city": "Exampletown", - "state": "NY", - "zip": "12321", - }, - ], - "postgres_example:customer": [ - { - "id": 1, - "name": "John Customer", - "email": "customer-1@example.com", - "address_id": 1, - } - ], - "manual_example:filing_cabinet": [ - {"id": 1, "authorized_user": "Jane Doe", "payment_card_id": "pay_bbb-bbb"} - ], - "manual_example:storage_unit": [ - {"box_id": 5, "email": "customer-1@example.com"} - ], - "postgres_example:orders": [ - { - "id": "ord_aaa-aaa", - "customer_id": 1, - "shipping_address_id": 2, - "payment_card_id": "pay_aaa-aaa", - }, - { - "id": "ord_ccc-ccc", - "customer_id": 1, - "shipping_address_id": 1, - "payment_card_id": "pay_aaa-aaa", - }, - { - "id": "ord_ddd-ddd", - "customer_id": 1, - "shipping_address_id": 1, - "payment_card_id": "pay_bbb-bbb", - }, - ], - } - - # ATTEMPT 1 - erasure request will pause to wait for confirmation that data has been destroyed from - # the filing cabinet - with 
pytest.raises(PrivacyRequestPaused): - await graph_task.run_erasure( - privacy_request, - erasure_policy, - postgres_and_manual_nodes("postgres_example", "manual_example"), - [integration_postgres_config, integration_manual_config], - {"email": "customer-1@example.com"}, - cached_data_for_erasures, - db, - ) - - paused_details = privacy_request.get_paused_collection_details() - assert paused_details.collection == CollectionAddress( - "manual_example", "filing_cabinet" - ) - assert paused_details.step == CurrentStep.erasure - assert len(paused_details.action_needed) == 1 - - assert paused_details.action_needed[0].locators == {"id": 1} - - assert paused_details.action_needed[0].get is None - assert paused_details.action_needed[0].update == {"authorized_user": None} - - # Mock confirming from user that there was no data in the filing cabinet - privacy_request.cache_manual_erasure_count( - CollectionAddress.from_string("manual_example:filing_cabinet"), - 0, - ) - - # Attempt 2 - erasure request will pause, waiting for confirmation that the box in the storage unit is destroyed. - with pytest.raises(PrivacyRequestPaused): - await graph_task.run_erasure( - privacy_request, - erasure_policy, - postgres_and_manual_nodes("postgres_example", "manual_example"), - [integration_postgres_config, integration_manual_config], - {"email": "customer-1@example.com"}, - cached_data_for_erasures, - db, - ) - - paused_details = privacy_request.get_paused_collection_details() - assert paused_details.collection == CollectionAddress( - "manual_example", "storage_unit" - ) - assert paused_details.step == CurrentStep.erasure - assert len(paused_details.action_needed) == 1 - - assert paused_details.action_needed[0].locators == {"box_id": 5} - - assert paused_details.action_needed[0].get is None - assert paused_details.action_needed[0].update == {"email": None} - - # Mock confirming from user that storage unit erasure is complete - privacy_request.cache_manual_erasure_count( - CollectionAddress.from_string("manual_example:storage_unit"), 1 - ) - - # Attempt 3 - We've confirmed data has been removed for manual nodes so we can proceed with the rest of the erasure - v = await graph_task.run_erasure( - privacy_request, - erasure_policy, - postgres_and_manual_nodes("postgres_example", "manual_example"), - [integration_postgres_config, integration_manual_config], - {"email": "customer-1@example.com"}, - cached_data_for_erasures, - db, - ) - - assert v == { - "postgres_example:customer": 0, - "manual_example:storage_unit": 1, - "postgres_example:payment_card": 0, - "postgres_example:orders": 0, - "postgres_example:address": 0, - "manual_example:filing_cabinet": 0, - } - - assert privacy_request.get_paused_collection_details() is None diff --git a/tests/ops/integration_tests/test_mongo_task.py b/tests/ops/integration_tests/test_mongo_task.py index 2fd53fb928..4cf1d93dd2 100644 --- a/tests/ops/integration_tests/test_mongo_task.py +++ b/tests/ops/integration_tests/test_mongo_task.py @@ -1,8 +1,5 @@ import copy from datetime import datetime -from unittest import mock -from unittest.mock import Mock -from uuid import uuid4 import pytest from bson import ObjectId @@ -19,12 +16,12 @@ from fides.api.models.connectionconfig import ConnectionConfig from fides.api.models.datasetconfig import convert_dataset_to_graph from fides.api.models.policy import Policy -from fides.api.models.privacy_request import PrivacyRequest +from fides.api.models.privacy_request import RequestTask from fides.api.service.connectors import get_connector -from 
fides.api.task import graph_task from fides.api.task.filter_results import filter_data_categories from fides.api.task.graph_task import get_cached_data_for_erasures +from ...conftest import access_runner_tester, erasure_runner_tester from ..graph.graph_test_util import assert_rows_match, erasure_policy, field from ..task.traversal_data import ( combined_mongo_postgresql_graph, @@ -37,74 +34,96 @@ @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_combined_erasure_task( db, - mongo_inserts, postgres_inserts, integration_mongodb_config, integration_postgres_config, integration_mongodb_connector, + privacy_request_with_erasure_policy, + privacy_request, + mongo_inserts, + dsr_version, + request, ): """Includes examples of mongo nested and array erasures""" - policy = erasure_policy("A", "B") + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + policy = erasure_policy(db, "user.name", "user.contact") seed_email = postgres_inserts["customer"][0]["email"] - privacy_request = PrivacyRequest(id=f"test_sql_erasure_task_{uuid4()}") + privacy_request_with_erasure_policy.policy_id = policy.id + privacy_request_with_erasure_policy.save(db) + mongo_dataset, postgres_dataset = combined_mongo_postgresql_graph( integration_postgres_config, integration_mongodb_config ) field([postgres_dataset], "postgres_example", "address", "city").data_categories = [ - "A" + "user.name" ] field( [postgres_dataset], "postgres_example", "address", "state" - ).data_categories = ["B"] + ).data_categories = ["user.contact"] field([postgres_dataset], "postgres_example", "address", "zip").data_categories = [ - "C" + "user.email" ] field( [postgres_dataset], "postgres_example", "customer", "name" - ).data_categories = ["A"] - field([mongo_dataset], "mongo_test", "address", "city").data_categories = ["A"] - field([mongo_dataset], "mongo_test", "address", "state").data_categories = ["B"] - field([mongo_dataset], "mongo_test", "address", "zip").data_categories = ["C"] + ).data_categories = ["user.name"] + field([mongo_dataset], "mongo_test", "address", "city").data_categories = [ + "user.name" + ] + field([mongo_dataset], "mongo_test", "address", "state").data_categories = [ + "user.contact" + ] + field([mongo_dataset], "mongo_test", "address", "zip").data_categories = [ + "user.email" + ] field( [mongo_dataset], "mongo_test", "customer_details", "workplace_info", "position" - ).data_categories = ["A"] + ).data_categories = ["user.name"] field( [mongo_dataset], "mongo_test", "customer_details", "emergency_contacts", "phone" - ).data_categories = ["B"] + ).data_categories = ["user.contact"] field( [mongo_dataset], "mongo_test", "customer_details", "children" - ).data_categories = ["B"] + ).data_categories = ["user.contact"] field( [mongo_dataset], "mongo_test", "internal_customer_profile", "derived_interests" - ).data_categories = ["B"] - field([mongo_dataset], "mongo_test", "employee", "email").data_categories = ["B"] + ).data_categories = ["user.contact"] + field([mongo_dataset], "mongo_test", "employee", "email").data_categories = [ + "user.contact" + ] field( [mongo_dataset], "mongo_test", "customer_feedback", "customer_information", "phone", - ).data_categories = ["A"] + ).data_categories = ["user.name"] field( [mongo_dataset], "mongo_test", "conversations", "thread", "chat_name" - ).data_categories = ["B"] + ).data_categories = ["user.contact"] field( [mongo_dataset], "mongo_test", "flights", 
"passenger_information", "passenger_ids", - ).data_categories = ["A"] - field([mongo_dataset], "mongo_test", "aircraft", "planes").data_categories = ["A"] + ).data_categories = ["user.name"] + field([mongo_dataset], "mongo_test", "aircraft", "planes").data_categories = [ + "user.name" + ] graph = DatasetGraph(mongo_dataset, postgres_dataset) - await graph_task.run_access_request( - privacy_request, + access_runner_tester( + privacy_request_with_erasure_policy, policy, graph, [integration_mongodb_config, integration_postgres_config], @@ -112,13 +131,13 @@ async def test_combined_erasure_task( db, ) - x = await graph_task.run_erasure( - privacy_request, + x = erasure_runner_tester( + privacy_request_with_erasure_policy, policy, graph, [integration_mongodb_config, integration_postgres_config], {"email": seed_email}, - get_cached_data_for_erasures(privacy_request.id), + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) @@ -139,8 +158,7 @@ async def test_combined_erasure_task( "mongo_test:rewards": 0, } - privacy_request = PrivacyRequest(id=f"test_sql_erasure_task_{uuid4()}") - rerun_access = await graph_task.run_access_request( + rerun_access = access_runner_tester( privacy_request, policy, graph, @@ -249,34 +267,50 @@ async def test_combined_erasure_task( @pytest.mark.integration_mongodb @pytest.mark.integration @pytest.mark.asyncio -async def test_mongo_erasure_task(db, mongo_inserts, integration_mongodb_config): - policy = erasure_policy("A", "B") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +async def test_mongo_erasure_task( + db, + mongo_inserts, + integration_mongodb_config, + dsr_version, + request, + privacy_request_with_erasure_policy, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + policy = erasure_policy(db, "user.name", "user.contact") seed_email = mongo_inserts["customer"][0]["email"] - privacy_request = PrivacyRequest(id=f"test_sql_erasure_task_{uuid4()}") + privacy_request_with_erasure_policy.policy_id = policy.id + privacy_request_with_erasure_policy.save(db) dataset, graph = integration_db_mongo_graph( "mongo_test", integration_mongodb_config.key ) - field([dataset], "mongo_test", "address", "city").data_categories = ["A"] - field([dataset], "mongo_test", "address", "state").data_categories = ["B"] - field([dataset], "mongo_test", "address", "zip").data_categories = ["C"] - field([dataset], "mongo_test", "customer", "name").data_categories = ["A"] + field([dataset], "mongo_test", "address", "city").data_categories = ["user.name"] + field([dataset], "mongo_test", "address", "state").data_categories = [ + "user.contact" + ] + field([dataset], "mongo_test", "address", "zip").data_categories = ["user.email"] + field([dataset], "mongo_test", "customer", "name").data_categories = ["user.name"] - await graph_task.run_access_request( - privacy_request, + access_runner_tester( + privacy_request_with_erasure_policy, policy, graph, [integration_mongodb_config], {"email": seed_email}, db, ) - v = await graph_task.run_erasure( - privacy_request, + v = erasure_runner_tester( + privacy_request_with_erasure_policy, policy, graph, [integration_mongodb_config], {"email": seed_email}, - get_cached_data_for_erasures(privacy_request.id), + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) assert v == { @@ -290,12 +324,20 @@ async def test_mongo_erasure_task(db, mongo_inserts, integration_mongodb_config) @pytest.mark.integration_mongodb @pytest.mark.integration 
@pytest.mark.asyncio -async def test_dask_mongo_task( - db, integration_mongodb_config: ConnectionConfig +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +async def test_access_mongo_task( + db, + integration_mongodb_config: ConnectionConfig, + dsr_version, + privacy_request, + request, ) -> None: - privacy_request = PrivacyRequest(id=f"test_mongo_task_{uuid4()}") + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, empty_policy, integration_db_graph("mongo_test", integration_mongodb_config.key), @@ -330,9 +372,15 @@ async def test_dask_mongo_task( async def test_composite_key_erasure( db, integration_mongodb_config: ConnectionConfig, + mongo_inserts, + privacy_request, + privacy_request_with_erasure_policy, + use_dsr_3_0, ) -> None: - privacy_request = PrivacyRequest(id=f"test_mongo_task_{uuid4()}") - policy = erasure_policy("A") + policy = erasure_policy(db, "user.name") + privacy_request_with_erasure_policy.policy_id = policy.id + privacy_request_with_erasure_policy.save(db) + customer = Collection( name="customer", fields=[ @@ -361,7 +409,7 @@ async def test_composite_key_erasure( ScalarField( name="description", data_type_converter=StringTypeConverter(), - data_categories=["A"], + data_categories=["user.name"], ), ScalarField( name="customer_id", @@ -377,8 +425,8 @@ async def test_composite_key_erasure( connection_key=integration_mongodb_config.key, ) - access_request_data = await graph_task.run_access_request( - privacy_request, + access_request_data = access_runner_tester( + privacy_request_with_erasure_policy, policy, DatasetGraph(dataset), [integration_mongodb_config], @@ -393,13 +441,13 @@ async def test_composite_key_erasure( assert composite_pk_test["customer_id"] == 1 # erasure - erasure = await graph_task.run_erasure( - privacy_request, + erasure = erasure_runner_tester( + privacy_request_with_erasure_policy, policy, DatasetGraph(dataset), [integration_mongodb_config], {"email": "employee-1@example.com"}, - get_cached_data_for_erasures(privacy_request.id), + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) @@ -407,8 +455,8 @@ async def test_composite_key_erasure( # re-run access request. Description has been # nullified here. - privacy_request = PrivacyRequest(id=f"test_mongo_task_{uuid4()}") - access_request_data = await graph_task.run_access_request( + + access_request_data = access_runner_tester( privacy_request, policy, DatasetGraph(dataset), @@ -426,12 +474,18 @@ async def test_composite_key_erasure( async def test_access_erasure_type_conversion( db, integration_mongodb_config: ConnectionConfig, + privacy_request_with_erasure_policy, + use_dsr_3_0, ) -> None: """Retrieve data from the type_link table. 
This requires retrieving data from the employee foreign_id field, which is an object_id stored as a string, and converting it into an object_id to query against the type_link_test._id field.""" - privacy_request = PrivacyRequest(id=f"test_mongo_task_{uuid4()}") - policy = erasure_policy("A") + + policy = erasure_policy(db, "user.name") + + privacy_request_with_erasure_policy.policy_id = policy.id + privacy_request_with_erasure_policy.save(db) + employee = Collection( name="employee", fields=[ @@ -460,7 +514,7 @@ async def test_access_erasure_type_conversion( ScalarField( name="name", data_type_converter=StringTypeConverter(), - data_categories=["A"], + data_categories=["user.name"], ), ScalarField(name="key", data_type_converter=IntTypeConverter()), ], @@ -472,8 +526,8 @@ async def test_access_erasure_type_conversion( connection_key=integration_mongodb_config.key, ) - access_request_data = await graph_task.run_access_request( - privacy_request, + access_request_data = access_runner_tester( + privacy_request_with_erasure_policy, policy, DatasetGraph(dataset), [integration_mongodb_config], @@ -488,13 +542,13 @@ async def test_access_erasure_type_conversion( assert link["_id"] == ObjectId("000000000000000000000001") # erasure - erasure = await graph_task.run_erasure( - privacy_request, + erasure = erasure_runner_tester( + privacy_request_with_erasure_policy, policy, DatasetGraph(dataset), [integration_mongodb_config], {"email": "employee-1@example.com"}, - get_cached_data_for_erasures(privacy_request.id), + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) @@ -503,6 +557,10 @@ async def test_access_erasure_type_conversion( @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_object_querying_mongo( db, privacy_request, @@ -510,7 +568,10 @@ async def test_object_querying_mongo( policy, integration_mongodb_config, integration_postgres_config, + dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 postgres_config = copy.copy(integration_postgres_config) dataset_postgres = Dataset(**example_datasets[0]) @@ -521,7 +582,7 @@ async def test_object_querying_mongo( ) dataset_graph = DatasetGraph(*[graph, mongo_graph]) - access_request_results = await graph_task.run_access_request( + access_request_results = access_runner_tester( privacy_request, policy, dataset_graph, @@ -603,16 +664,19 @@ async def test_object_querying_mongo( @pytest.mark.integration @pytest.mark.asyncio async def test_get_cached_data_for_erasures( - integration_postgres_config, integration_mongodb_config, policy, db + integration_postgres_config, + integration_mongodb_config, + policy, + db, + use_dsr_2_0, + privacy_request, ) -> None: - privacy_request = PrivacyRequest(id=f"test_mongo_task_{uuid4()}") - mongo_dataset, postgres_dataset = combined_mongo_postgresql_graph( integration_postgres_config, integration_mongodb_config ) graph = DatasetGraph(mongo_dataset, postgres_dataset) - access_request_results = await graph_task.run_access_request( + access_request_results = access_runner_tester( privacy_request, policy, graph, @@ -646,6 +710,61 @@ async def test_get_cached_data_for_erasures( @pytest.mark.integration @pytest.mark.asyncio +async def test_get_saved_data_for_erasures_3_0( + integration_postgres_config, + integration_mongodb_config, + policy, + db, + use_dsr_3_0, + privacy_request, +) -> None: + mongo_dataset, postgres_dataset = combined_mongo_postgresql_graph( 
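# Editor's note: under DSR 3.0 each access RequestTask persists its results on
# the task row itself rather than only in the Redis cache: the filtered rows
# (access_data) plus an unfiltered "erasure format" copy (data_for_erasures)
# that is handed to the erasure node for the same collection. The assertions
# below read both back, roughly:
#
#   task = privacy_request.access_tasks.filter(
#       RequestTask.collection_address == "mongo_test:conversations"
#   ).first()
#   task.get_decoded_data_for_erasures()  # unfiltered; placeholders intact
#   task.get_decoded_access_data()        # filtered on matched array elements
#
# (this restates the test's own calls; it adds no new API surface)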
+ integration_postgres_config, integration_mongodb_config + ) + graph = DatasetGraph(mongo_dataset, postgres_dataset) + + access_runner_tester( + privacy_request, + policy, + graph, + [integration_mongodb_config, integration_postgres_config], + {"email": "customer-1@example.com"}, + db, + ) + + conversations_task = privacy_request.access_tasks.filter( + RequestTask.collection_address == "mongo_test:conversations" + ).first() + + # Assert access task saved data in erasure format, that will be copied over to the erasure + # nodes of the same name + assert conversations_task.get_decoded_data_for_erasures()[0]["thread"] == [ + { + "comment": "com_0001", + "message": "hello, testing in-flight chat feature", + "chat_name": "John C", + "ccn": "123456789", + }, + "FIDESOPS_DO_NOT_MASK", + ] + + # The access request results are filtered on array data, because it was an entrypoint into the node. + assert conversations_task.get_decoded_access_data()[0]["thread"] == [ + { + "comment": "com_0001", + "message": "hello, testing in-flight chat feature", + "chat_name": "John C", + "ccn": "123456789", + } + ] + + +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_return_all_elements_config_access_request( db, privacy_request, @@ -654,6 +773,8 @@ async def test_return_all_elements_config_access_request( integration_mongodb_config, integration_postgres_config, integration_mongodb_connector, + dsr_version, + request, ): """Annotating array entrypoint field with return_all_elements=true means both the entire array is returned from the queried data and used to locate data in other collections @@ -661,6 +782,8 @@ async def test_return_all_elements_config_access_request( mongo_test:internal_customer_profile.customer_identifiers.derived_phone field and mongo_test:rewards.owner field have return_all_elements set to True """ + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + postgres_config = copy.copy(integration_postgres_config) dataset_postgres = Dataset(**example_datasets[0]) @@ -671,7 +794,7 @@ async def test_return_all_elements_config_access_request( ) dataset_graph = DatasetGraph(*[graph, mongo_graph]) - access_request_results = await graph_task.run_access_request( + access_request_results = access_runner_tester( privacy_request, policy, dataset_graph, @@ -699,6 +822,10 @@ async def test_return_all_elements_config_access_request( @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_return_all_elements_config_erasure( db, mongo_inserts, @@ -706,32 +833,38 @@ async def test_return_all_elements_config_erasure( integration_mongodb_config, integration_postgres_config, integration_mongodb_connector, + dsr_version, + request, + privacy_request, ): """Includes examples of mongo nested and array erasures""" - policy = erasure_policy("A", "B") + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + policy = erasure_policy(db, "user.name", "user.contact") + privacy_request.policy_id = policy.id + privacy_request.save(db) - privacy_request = PrivacyRequest(id=f"test_sql_erasure_task_{uuid4()}") mongo_dataset, postgres_dataset = combined_mongo_postgresql_graph( integration_postgres_config, integration_mongodb_config ) field( [mongo_dataset], "mongo_test", "rewards", "owner", "phone" - ).data_categories = ["A"] + ).data_categories = ["user.name"] field( [mongo_dataset], "mongo_test", 
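# Editor's note: return_all_elements is an annotation on the two array fields
# the access-request test above names in its docstring; it is set in the
# dataset's field-level meta, roughly (spelling assumed, per fideslang docs):
#
#   fields:
#     - name: derived_phone
#       fides_meta:
#         return_all_elements: true
#
# When true, the entire array is returned from the queried data — and, on
# erasure, masked in full — rather than only the elements that matched the
# incoming identity value.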
"internal_customer_profile", "customer_identifiers", "derived_phone", - ).data_categories = ["B"] + ).data_categories = ["user.contact"] graph = DatasetGraph(mongo_dataset, postgres_dataset) seed_email = postgres_inserts["customer"][0]["email"] seed_phone = mongo_inserts["rewards"][0]["owner"][0]["phone"] - await graph_task.run_access_request( + access_runner_tester( privacy_request, policy, graph, @@ -740,7 +873,7 @@ async def test_return_all_elements_config_erasure( db, ) - x = await graph_task.run_erasure( + x = erasure_runner_tester( privacy_request, policy, graph, @@ -776,14 +909,23 @@ async def test_return_all_elements_config_erasure( @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_array_querying_mongo( db, privacy_request, + privacy_request_status_pending, example_datasets, policy, integration_mongodb_config, integration_postgres_config, + dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + postgres_config = copy.copy(integration_postgres_config) dataset_postgres = Dataset(**example_datasets[0]) @@ -794,7 +936,7 @@ async def test_array_querying_mongo( ) dataset_graph = DatasetGraph(*[graph, mongo_graph]) - access_request_results = await graph_task.run_access_request( + access_request_results = access_runner_tester( privacy_request, policy, dataset_graph, @@ -921,18 +1063,17 @@ async def test_array_querying_mongo( dataset_name="mongo_test", collection_name="conversations", status="complete" ) assert conversation_logs.count() == 1 - assert conversation_logs[0].fields_affected == [ - { - "path": "mongo_test:conversations:thread.chat_name", - "field_name": "thread.chat_name", - "data_categories": ["user.name"], - }, - { - "path": "mongo_test:conversations:thread.ccn", - "field_name": "thread.ccn", - "data_categories": ["user.financial.bank_account"], - }, - ] + assert { + "path": "mongo_test:conversations:thread.chat_name", + "field_name": "thread.chat_name", + "data_categories": ["user.name"], + } in conversation_logs[0].fields_affected + + assert { + "path": "mongo_test:conversations:thread.ccn", + "field_name": "thread.ccn", + "data_categories": ["user.financial.bank_account"], + } in conversation_logs[0].fields_affected # Integer field mongo_test:flights.plane used to locate only matching elem in mongo_test:aircraft:planes array field assert access_request_results["mongo_test:aircraft"][0]["planes"] == ["30005"] @@ -1040,9 +1181,8 @@ async def test_array_querying_mongo( ] # Run again with different email - privacy_request = PrivacyRequest(id=f"test_mongo_task_{uuid4()}") - access_request_results = await graph_task.run_access_request( - privacy_request, + access_request_results = access_runner_tester( + privacy_request_status_pending, policy, dataset_graph, [postgres_config, integration_mongodb_config], @@ -1081,7 +1221,7 @@ def connector(self, integration_mongodb_config): return get_connector(integration_mongodb_config) @pytest.fixture - def traversal_node(self, example_datasets, integration_mongodb_config): + def execution_node(self, example_datasets, integration_mongodb_config): dataset = Dataset(**example_datasets[1]) graph = convert_dataset_to_graph(dataset, integration_mongodb_config.key) customer_details_collection = None @@ -1091,18 +1231,15 @@ def traversal_node(self, example_datasets, integration_mongodb_config): break node = Node(graph, customer_details_collection) traversal_node = TraversalNode(node) - return traversal_node 
+ return traversal_node.to_mock_execution_node() - @mock.patch("fides.api.graph.traversal.TraversalNode.incoming_edges") def test_retrieving_data( self, - mock_incoming_edges: Mock, privacy_request, - db, connector, - traversal_node, + execution_node, ): - mock_incoming_edges.return_value = { + execution_node.incoming_edges = { Edge( FieldAddress("fake_dataset", "fake_collection", "id"), FieldAddress("mongo_test", "customer_details", "customer_id"), @@ -1110,61 +1247,67 @@ def test_retrieving_data( } results = connector.retrieve_data( - traversal_node, Policy(), privacy_request, {"customer_id": [1]} + execution_node, + Policy(), + privacy_request, + RequestTask(), + {"customer_id": [1]}, ) assert results[0]["customer_id"] == 1 - @mock.patch("fides.api.graph.traversal.TraversalNode.incoming_edges") def test_retrieving_data_no_input( self, - mock_incoming_edges: Mock, privacy_request, - db, connector, - traversal_node, + execution_node, ): - mock_incoming_edges.return_value = { + execution_node.incoming_edges = { Edge( FieldAddress("fake_dataset", "fake_collection", "email"), FieldAddress("mongo_test", "customer_details", "customer_id"), ) } results = connector.retrieve_data( - traversal_node, Policy(), privacy_request, {"customer_id": []} + execution_node, + Policy(), + privacy_request, + RequestTask(), + {"customer_id": []}, ) assert results == [] - results = connector.retrieve_data(traversal_node, Policy(), privacy_request, {}) + results = connector.retrieve_data( + execution_node, Policy(), privacy_request, RequestTask(), {} + ) assert results == [] results = connector.retrieve_data( - traversal_node, Policy(), privacy_request, {"bad_key": ["test"]} + execution_node, + Policy(), + privacy_request, + RequestTask(), + {"bad_key": ["test"]}, ) assert results == [] results = connector.retrieve_data( - traversal_node, Policy(), privacy_request, {"email": [None]} + execution_node, Policy(), privacy_request, RequestTask(), {"email": [None]} ) assert results == [] results = connector.retrieve_data( - traversal_node, Policy(), privacy_request, {"email": None} + execution_node, Policy(), privacy_request, RequestTask(), {"email": None} ) assert results == [] - @mock.patch("fides.api.graph.traversal.TraversalNode.incoming_edges") def test_retrieving_data_input_not_in_table( self, - mock_incoming_edges: Mock, - db, privacy_request, - connection_config, - example_datasets, connector, - traversal_node, + execution_node, ): - mock_incoming_edges.return_value = { + execution_node.incoming_edges = { Edge( FieldAddress("fake_dataset", "fake_collection", "email"), FieldAddress("mongo_test", "customer_details", "customer_id"), @@ -1172,7 +1315,11 @@ def test_retrieving_data_input_not_in_table( } results = connector.retrieve_data( - traversal_node, Policy(), privacy_request, {"customer_id": [5]} + execution_node, + Policy(), + privacy_request, + RequestTask(), + {"customer_id": [5]}, ) assert results == [] diff --git a/tests/ops/integration_tests/test_privacy_request_logging.py b/tests/ops/integration_tests/test_privacy_request_logging.py index cb981ceb05..f32166795d 100644 --- a/tests/ops/integration_tests/test_privacy_request_logging.py +++ b/tests/ops/integration_tests/test_privacy_request_logging.py @@ -35,8 +35,14 @@ def mock_send(self) -> Generator: yield mock_send @pytest.mark.usefixtures("zendesk_runner") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_access_error_logs( self, + dsr_version, + request, mock_send, api_client, url, @@ -45,6 +51,8 @@ def 
test_access_error_logs( loguru_caplog, provided_identity_value, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + response = api_client.post( url, headers=generate_auth_header(scopes=[PRIVACY_REQUEST_CREATE]), @@ -76,8 +84,14 @@ def test_access_error_logs( ) @pytest.mark.usefixtures("typeform_runner") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_erasure_error_logs( self, + dsr_version, + request, mock_send, api_client, url, @@ -87,6 +101,8 @@ async def test_erasure_error_logs( typeform_secrets, provided_identity_value, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + masking_strict = CONFIG.execution.masking_strict CONFIG.execution.masking_strict = False @@ -123,15 +139,30 @@ async def test_erasure_error_logs( CONFIG.execution.masking_strict = masking_strict @pytest.mark.usefixtures("klaviyo_runner") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_consent_error_logs( self, + dsr_version, + request, mock_send, klaviyo_runner, consent_policy, loguru_caplog, provided_identity_value, ): - with pytest.raises(ClientUnsuccessfulException): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + if dsr_version == "use_dsr_2_0": + with pytest.raises(ClientUnsuccessfulException): + await klaviyo_runner.new_consent_request( + consent_policy, + {"email": provided_identity_value}, + privacy_request_id="123", + ) + else: await klaviyo_runner.new_consent_request( consent_policy, {"email": provided_identity_value}, diff --git a/tests/ops/integration_tests/test_sql_task.py b/tests/ops/integration_tests/test_sql_task.py index 6506df4e48..d7a8c4d37d 100644 --- a/tests/ops/integration_tests/test_sql_task.py +++ b/tests/ops/integration_tests/test_sql_task.py @@ -2,7 +2,6 @@ from datetime import datetime from unittest import mock from unittest.mock import Mock -from uuid import uuid4 import pytest from fideslang import Dataset @@ -21,13 +20,13 @@ from fides.api.models.connectionconfig import ConnectionConfig from fides.api.models.datasetconfig import convert_dataset_to_graph from fides.api.models.policy import ActionType, Policy, Rule, RuleTarget -from fides.api.models.privacy_request import ExecutionLog, PrivacyRequest +from fides.api.models.privacy_request import ExecutionLog, RequestTask from fides.api.service.connectors import get_connector -from fides.api.task import graph_task from fides.api.task.filter_results import filter_data_categories from fides.api.task.graph_task import get_cached_data_for_erasures from fides.config import CONFIG +from ...conftest import access_runner_tester, erasure_runner_tester from ..graph.graph_test_util import ( assert_rows_match, erasure_policy, @@ -40,45 +39,50 @@ str_converter, ) -sample_postgres_configuration_policy = erasure_policy( - "system.operations", - "user.unique_id", - "user.sensor", - "user.contact.address.city", - "user.contact.email", - "user.contact.address.postal_code", - "user.contact.address.state", - "user.contact.address.street", - "user.financial.bank_account", - "user.financial", - "user.name", - "user", -) - @pytest.mark.integration_postgres @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_sql_erasure_ignores_collections_without_pk( - db, postgres_inserts, integration_postgres_config + db, + postgres_inserts, + integration_postgres_config, + privacy_request, + 
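# Editor's note: for SQL connectors a collection is only maskable when at
# least one of its fields is flagged as a primary key; the test below clears
# that flag on address.id —
#
#   field([dataset], "postgres_example", "address", "id").primary_key = False
#
# — so the erasure runner visits the collection but masks nothing. The same
# rule explains the DSR 3.0 retry assertions later in this file, where
# PK-less collections complete because there is nothing to erase.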
dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + seed_email = postgres_inserts["customer"][0]["email"] policy = erasure_policy( - "A", "B" + db, "user.contact", "system.operations" ) # makes an erasure policy with two data categories to match against + privacy_request.policy_id = policy.id + privacy_request.save(db) dataset = integration_db_dataset("postgres_example", "postgres_example") field([dataset], "postgres_example", "address", "id").primary_key = False - # set categories: A,B will be marked erasable, C will not - field([dataset], "postgres_example", "address", "city").data_categories = ["A"] - field([dataset], "postgres_example", "address", "state").data_categories = ["B"] - field([dataset], "postgres_example", "address", "zip").data_categories = ["C"] - field([dataset], "postgres_example", "customer", "name").data_categories = ["A"] + # set categories: user.contact, system.operations will be marked erasable, user.sensor will not + field([dataset], "postgres_example", "address", "city").data_categories = [ + "user.contact" + ] + field([dataset], "postgres_example", "address", "state").data_categories = [ + "system.operations" + ] + field([dataset], "postgres_example", "address", "zip").data_categories = [ + "user.sensor" + ] + field([dataset], "postgres_example", "customer", "name").data_categories = [ + "user.contact" + ] graph = DatasetGraph(dataset) - privacy_request = PrivacyRequest(id=str(uuid4())) - await graph_task.run_access_request( + access_runner_tester( privacy_request, policy, graph, @@ -86,7 +90,7 @@ async def test_sql_erasure_ignores_collections_without_pk( {"email": seed_email}, db, ) - v = await graph_task.run_erasure( + v = erasure_runner_tester( privacy_request, policy, graph, @@ -125,12 +129,23 @@ async def test_sql_erasure_ignores_collections_without_pk( @pytest.mark.integration_postgres @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_composite_key_erasure( db, integration_postgres_config: ConnectionConfig, + dsr_version, + request, + privacy_request, + privacy_request_with_erasure_policy, ) -> None: - privacy_request = PrivacyRequest(id=str(uuid4())) - policy = erasure_policy("A") + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + policy = erasure_policy(db, "user") + privacy_request_with_erasure_policy.policy_id = policy.id + privacy_request_with_erasure_policy.save(db) customer = Collection( name="customer", fields=[ @@ -159,7 +174,7 @@ async def test_composite_key_erasure( ScalarField( name="description", data_type_converter=StringTypeConverter(), - data_categories=["A"], + data_categories=["user"], ), ScalarField( name="customer_id", @@ -177,8 +192,8 @@ async def test_composite_key_erasure( connection_key=integration_postgres_config.key, ) - access_request_data = await graph_task.run_access_request( - privacy_request, + access_request_data = access_runner_tester( + privacy_request_with_erasure_policy, policy, DatasetGraph(dataset), [integration_postgres_config], @@ -192,13 +207,13 @@ async def test_composite_key_erasure( assert composite_pk_test["customer_id"] == 1 # erasure - erasure = await graph_task.run_erasure( - privacy_request, + erasure = erasure_runner_tester( + privacy_request_with_erasure_policy, policy, DatasetGraph(dataset), [integration_postgres_config], {"email": "employee-1@example.com"}, - get_cached_data_for_erasures(privacy_request.id), + 
get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) @@ -209,8 +224,7 @@ async def test_composite_key_erasure( # re-run access request. Description has been # nullified here. - privacy_request = PrivacyRequest(id=str(uuid4())) - access_request_data = await graph_task.run_access_request( + access_request_data = access_runner_tester( privacy_request, policy, DatasetGraph(dataset), @@ -230,20 +244,44 @@ async def test_composite_key_erasure( @pytest.mark.integration_postgres @pytest.mark.integration @pytest.mark.asyncio -async def test_sql_erasure_task(db, postgres_inserts, integration_postgres_config): +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) +async def test_sql_erasure_task( + db, + postgres_inserts, + integration_postgres_config, + privacy_request, + dsr_version, + request, +): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + seed_email = postgres_inserts["customer"][0]["email"] - policy = erasure_policy("A", "B") + policy = erasure_policy(db, "user.name", "system") + privacy_request.policy_id = policy.id + privacy_request.save(db) + dataset = integration_db_dataset("postgres_example", "postgres_example") field([dataset], "postgres_example", "address", "id").primary_key = True - # set categories: A,B will be marked erasable, C will not - field([dataset], "postgres_example", "address", "city").data_categories = ["A"] - field([dataset], "postgres_example", "address", "state").data_categories = ["B"] - field([dataset], "postgres_example", "address", "zip").data_categories = ["C"] - field([dataset], "postgres_example", "customer", "name").data_categories = ["A"] + # set categories: user.name,system will be marked erasable, user.contact will not + # (data category labels are arbitrary) + field([dataset], "postgres_example", "address", "city").data_categories = [ + "user.name" + ] + field([dataset], "postgres_example", "address", "state").data_categories = [ + "system" + ] + field([dataset], "postgres_example", "address", "zip").data_categories = [ + "user.contact" + ] + field([dataset], "postgres_example", "customer", "name").data_categories = [ + "user.name" + ] graph = DatasetGraph(dataset) - privacy_request = PrivacyRequest(id=str(uuid4())) - await graph_task.run_access_request( + access_runner_tester( privacy_request, policy, graph, @@ -251,7 +289,7 @@ async def test_sql_erasure_task(db, postgres_inserts, integration_postgres_confi {"email": seed_email}, db, ) - v = await graph_task.run_erasure( + v = erasure_runner_tester( privacy_request, policy, graph, @@ -272,15 +310,23 @@ async def test_sql_erasure_task(db, postgres_inserts, integration_postgres_confi @pytest.mark.integration_postgres @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.timeout(5) +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_postgres_access_request_task( db, policy, integration_postgres_config, postgres_integration_db, + privacy_request, + dsr_version, + request, ) -> None: - privacy_request = PrivacyRequest(id=str(uuid4())) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, integration_db_graph("postgres_example"), @@ -359,27 +405,41 @@ async def test_postgres_access_request_task( @pytest.mark.integration_postgres @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async 
def test_postgres_privacy_requests_against_non_default_schema( db, - policy, postgres_connection_config_with_schema, postgres_integration_db, erasure_policy, + request, + dsr_version, + privacy_request_with_erasure_policy, ) -> None: """Assert that the postgres connector can make access and erasure requests against the non-default (public) schema""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - privacy_request = PrivacyRequest(id=str(uuid4())) database_name = "postgres_backup" customer_email = "customer-500@example.com" dataset = integration_db_dataset( database_name, postgres_connection_config_with_schema.key ) + # Update data category on customer name - need to do this before the access runner, + # DSR 3.0 saves this upfront on the access and erasure Request Tasks + field([dataset], database_name, "customer", "name").data_categories = ["user.name"] + rule = erasure_policy.rules[0] + target = rule.targets[0] + target.data_category = "user" + target.save(db) + graph = DatasetGraph(dataset) - access_results = await graph_task.run_access_request( - privacy_request, - policy, + access_results = access_runner_tester( + privacy_request_with_erasure_policy, + erasure_policy, graph, [postgres_connection_config_with_schema], {"email": customer_email}, @@ -409,20 +469,13 @@ async def test_postgres_privacy_requests_against_non_default_schema( ], } - rule = erasure_policy.rules[0] - target = rule.targets[0] - target.data_category = "user" - target.save(db) - # Update data category on customer name - field([dataset], database_name, "customer", "name").data_categories = ["user.name"] - - erasure_results = await graph_task.run_erasure( - privacy_request, + erasure_results = erasure_runner_tester( + privacy_request_with_erasure_policy, erasure_policy, graph, [postgres_connection_config_with_schema], {"email": customer_email}, - get_cached_data_for_erasures(privacy_request.id), + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) @@ -444,15 +497,22 @@ async def test_postgres_privacy_requests_against_non_default_schema( @pytest.mark.integration_mssql @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_mssql_access_request_task( db, policy, connection_config_mssql, mssql_integration_db, + privacy_request, + dsr_version, + request, ) -> None: - privacy_request = PrivacyRequest(id=str(uuid4())) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, integration_db_graph("my_mssql_db_1"), @@ -531,15 +591,22 @@ async def test_mssql_access_request_task( @pytest.mark.integration_mysql @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_mysql_access_request_task( db, policy, connection_config_mysql, mysql_integration_db, + privacy_request, + dsr_version, + request, ) -> None: - privacy_request = PrivacyRequest(id=str(uuid4())) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, integration_db_graph("my_mysql_db_1"), @@ -618,15 +685,22 @@ async def test_mysql_access_request_task( @pytest.mark.integration_mariadb @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async 
def test_mariadb_access_request_task( db, policy, connection_config_mariadb, mariadb_integration_db, + dsr_version, + request, + privacy_request, ) -> None: - privacy_request = PrivacyRequest(id=str(uuid4())) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, integration_db_graph("my_maria_db_1"), @@ -704,14 +778,21 @@ async def test_mariadb_access_request_task( @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_filter_on_data_categories( db, privacy_request, connection_config, example_datasets, policy, + dsr_version, + request, integration_postgres_config, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 postgres_config = copy.copy(integration_postgres_config) rule = Rule.create( @@ -739,7 +820,7 @@ async def test_filter_on_data_categories( graph = convert_dataset_to_graph(dataset, integration_postgres_config.key) dataset_graph = DatasetGraph(*[graph]) - access_request_results = await graph_task.run_access_request( + access_request_results = access_runner_tester( privacy_request, policy, dataset_graph, @@ -845,15 +926,24 @@ async def test_filter_on_data_categories( @pytest.mark.integration_postgres @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_access_erasure_type_conversion( db, integration_postgres_config: ConnectionConfig, + privacy_request_with_erasure_policy, + dsr_version, + request, ) -> None: """Retrieve data from the type_link table. This requires retrieving data from the employee id field, which is an int, and converting it into a string to query against the type_link_test.id field.""" - privacy_request = PrivacyRequest(id=str(uuid4())) - policy = erasure_policy("A") + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + policy = erasure_policy(db, "user.name") + privacy_request_with_erasure_policy.policy_id = policy.id + privacy_request_with_erasure_policy.save(db) employee = Collection( name="employee", fields=[ @@ -881,7 +971,7 @@ async def test_access_erasure_type_conversion( ScalarField( name="name", data_type_converter=StringTypeConverter(), - data_categories=["A"], + data_categories=["user.name"], ), ], ) @@ -892,8 +982,8 @@ async def test_access_erasure_type_conversion( connection_key=integration_postgres_config.key, ) - access_request_data = await graph_task.run_access_request( - privacy_request, + access_request_data = access_runner_tester( + privacy_request_with_erasure_policy, policy, DatasetGraph(dataset), [integration_postgres_config], @@ -908,13 +998,13 @@ async def test_access_erasure_type_conversion( assert link["id"] == "1" # erasure - erasure = await graph_task.run_erasure( - privacy_request, + erasure = erasure_runner_tester( + privacy_request_with_erasure_policy, policy, DatasetGraph(dataset), [integration_postgres_config], {"email": "employee-1@example.com"}, - get_cached_data_for_erasures(privacy_request.id), + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) @@ -932,24 +1022,22 @@ def connector(self, integration_postgres_config): return get_connector(integration_postgres_config) @pytest.fixture - def traversal_node(self, example_datasets, integration_postgres_config): + def execution_node(self, example_datasets, integration_postgres_config): dataset = 
Dataset(**example_datasets[0]) graph = convert_dataset_to_graph(dataset, integration_postgres_config.key) node = Node(graph, graph.collections[1]) # customer collection traversal_node = TraversalNode(node) - return traversal_node + return traversal_node.to_mock_execution_node() - @mock.patch("fides.api.graph.traversal.TraversalNode.incoming_edges") def test_retrieving_data( self, - mock_incoming_edges: Mock, privacy_request, db, connector, - traversal_node, + execution_node, postgres_integration_db, ): - mock_incoming_edges.return_value = { + execution_node.incoming_edges = { Edge( FieldAddress("fake_dataset", "fake_collection", "email"), FieldAddress("postgres_example_test_dataset", "customer", "email"), @@ -957,9 +1045,10 @@ def test_retrieving_data( } results = connector.retrieve_data( - traversal_node, + execution_node, Policy(), privacy_request, + RequestTask(), {"email": ["customer-1@example.com"]}, ) assert len(results) is 1 @@ -973,16 +1062,14 @@ def test_retrieving_data( } ] - @mock.patch("fides.api.graph.traversal.TraversalNode.incoming_edges") def test_retrieving_data_no_input( self, - mock_incoming_edges: Mock, privacy_request, db, connector, - traversal_node, + execution_node, ): - mock_incoming_edges.return_value = { + execution_node.incoming_edges = { Edge( FieldAddress("fake_dataset", "fake_collection", "email"), FieldAddress("postgres_example_test_dataset", "customer", "email"), @@ -990,47 +1077,50 @@ def test_retrieving_data_no_input( } assert [] == connector.retrieve_data( - traversal_node, Policy(), privacy_request, {"email": []} + execution_node, Policy(), privacy_request, RequestTask(), {"email": []} ) assert [] == connector.retrieve_data( - traversal_node, Policy(), privacy_request, {} + execution_node, Policy(), privacy_request, RequestTask(), {} ) assert [] == connector.retrieve_data( - traversal_node, Policy(), privacy_request, {"bad_key": ["test"]} + execution_node, + Policy(), + privacy_request, + RequestTask(), + {"bad_key": ["test"]}, ) assert [] == connector.retrieve_data( - traversal_node, Policy(), privacy_request, {"email": [None]} + execution_node, Policy(), privacy_request, RequestTask(), {"email": [None]} ) assert [] == connector.retrieve_data( - traversal_node, Policy(), privacy_request, {"email": None} + execution_node, Policy(), privacy_request, RequestTask(), {"email": None} ) - @mock.patch("fides.api.graph.traversal.TraversalNode.incoming_edges") def test_retrieving_data_input_not_in_table( self, - mock_incoming_edges: Mock, db, privacy_request, connection_config, example_datasets, connector, - traversal_node, + execution_node, postgres_integration_db, ): - mock_incoming_edges.return_value = { + execution_node.incoming_edges = { Edge( FieldAddress("fake_dataset", "fake_collection", "email"), FieldAddress("postgres_example_test_dataset", "customer", "email"), ) } results = connector.retrieve_data( - traversal_node, + execution_node, Policy(), privacy_request, + RequestTask(), {"email": ["customer_not_in_dataset@example.com"]}, ) assert results == [] @@ -1041,6 +1131,10 @@ def test_retrieving_data_input_not_in_table( class TestRetryIntegration: @mock.patch("fides.api.service.connectors.sql_connector.SQLConnector.retrieve_data") @pytest.mark.asyncio + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) async def test_retry_access_request( self, mock_retrieve, @@ -1048,9 +1142,26 @@ async def test_retry_access_request( privacy_request, connection_config, example_datasets, - policy, integration_postgres_config, + 
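# Editor's note: the retry knobs set just below (task_retry_count / delay /
# backoff) drive both scheduler paths, but the failure semantics differ, and
# that is exactly what the two assertion branches encode:
#   - DSR 2.0: the first failing collection raises out of the runner, so only
#     that node logs (in_processing -> retrying -> error) and nothing else runs.
#   - DSR 3.0: the runner does not raise; every node reachable from the root
#     attempts to process, each failing node retries per configuration, and
#     its downstream RequestTasks are marked "error" rather than left pending.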
dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + sample_postgres_configuration_policy = erasure_policy( + db, + "system.operations", + "user.unique_id", + "user.sensor", + "user.contact.address.city", + "user.contact.email", + "user.contact.address.postal_code", + "user.contact.address.state", + "user.contact.address.street", + "user.financial.bank_account", + "user.name", + ) + CONFIG.execution.task_retry_count = 1 CONFIG.execution.task_retry_delay = 0.1 CONFIG.execution.task_retry_backoff = 0.01 @@ -1063,8 +1174,40 @@ async def test_retry_access_request( mock_retrieve.side_effect = Exception("Insufficient data") # Call run_access_request with an email that isn't in the database - with pytest.raises(Exception) as exc: - await graph_task.run_access_request( + + if dsr_version == "use_dsr_2_0": + with pytest.raises(Exception) as exc: + # DSR 2.0 will raise an exception when the first node is hit, + # stopping all other nodes from running + access_runner_tester( + privacy_request, + sample_postgres_configuration_policy, + dataset_graph, + [integration_postgres_config], + {"email": "customer-5@example.com"}, + db, + ) + execution_logs = db.query(ExecutionLog).filter_by( + privacy_request_id=privacy_request.id + ) + assert 3 == execution_logs.count() + + # Execution starts with the employee collection, retries once on failure, and then errors + assert [ + ( + CollectionAddress(log.dataset_name, log.collection_name).value, + log.status.value, + ) + for log in execution_logs.order_by("created_at") + ] == [ + ("postgres_example_test_dataset:employee", "in_processing"), + ("postgres_example_test_dataset:employee", "retrying"), + ("postgres_example_test_dataset:employee", "error"), + ] + else: + # DSR 3.0 will run the nodes that can run, an exception on one node + # will not necessarily stop all nodes from running + access_runner_tester( privacy_request, sample_postgres_configuration_policy, dataset_graph, @@ -1072,27 +1215,57 @@ async def test_retry_access_request( {"email": "customer-5@example.com"}, db, ) - - execution_logs = db.query(ExecutionLog).filter_by( - privacy_request_id=privacy_request.id - ) - - assert 3 == execution_logs.count() - - # Execution starts with the employee collection, retries once on failure, and then errors - assert [ - ( - CollectionAddress(log.dataset_name, log.collection_name).value, - log.status.value, + execution_logs = db.query(ExecutionLog).filter_by( + privacy_request_id=privacy_request.id ) - for log in execution_logs.order_by("created_at") - ] == [ - ("postgres_example_test_dataset:employee", "in_processing"), - ("postgres_example_test_dataset:employee", "retrying"), - ("postgres_example_test_dataset:employee", "error"), - ] + assert 12 == execution_logs.count() + + # All four nodes directly downstream of the root node attempt to process, + # and nothing further processes downstream + assert [ + ( + CollectionAddress(log.dataset_name, log.collection_name).value, + log.status.value, + ) + for log in execution_logs.order_by( + ExecutionLog.collection_name, ExecutionLog.created_at + ) + ] == [ + ("postgres_example_test_dataset:customer", "in_processing"), + ("postgres_example_test_dataset:customer", "retrying"), + ("postgres_example_test_dataset:customer", "error"), + ("postgres_example_test_dataset:employee", "in_processing"), + ("postgres_example_test_dataset:employee", "retrying"), + ("postgres_example_test_dataset:employee", "error"), + ("postgres_example_test_dataset:report", "in_processing"), + 
("postgres_example_test_dataset:report", "retrying"), + ("postgres_example_test_dataset:report", "error"), + ("postgres_example_test_dataset:visit", "in_processing"), + ("postgres_example_test_dataset:visit", "retrying"), + ("postgres_example_test_dataset:visit", "error"), + ] + # Downstream request tasks were marked as error + assert [rt.status.value for rt in privacy_request.access_tasks] == [ + "complete", + "error", + "error", + "error", + "error", + "error", + "error", + "error", + "error", + "error", + "error", + "error", + "error", + ] @mock.patch("fides.api.service.connectors.sql_connector.SQLConnector.mask_data") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) @pytest.mark.asyncio async def test_retry_erasure( self, @@ -1103,7 +1276,25 @@ async def test_retry_erasure( example_datasets, policy, integration_postgres_config, + dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + + sample_postgres_configuration_policy = erasure_policy( + db, + "system.operations", + "user.unique_id", + "user.sensor", + "user.contact.address.city", + "user.contact.email", + "user.contact.address.postal_code", + "user.contact.address.state", + "user.contact.address.street", + "user.financial.bank_account", + "user.name", + ) + CONFIG.execution.task_retry_count = 2 CONFIG.execution.task_retry_delay = 0.1 CONFIG.execution.task_retry_backoff = 0.01 @@ -1115,64 +1306,133 @@ async def test_retry_erasure( # Mock errors with masking data mock_mask.side_effect = Exception("Insufficient data") + access_runner_tester( + privacy_request, + sample_postgres_configuration_policy, + dataset_graph, + [integration_postgres_config], + {"email": "customer-5@example.com"}, + db, + ) + # Call run_erasure with an email that isn't in the database - with pytest.raises(Exception): - await graph_task.run_erasure( + if dsr_version == "use_dsr_2_0": + with pytest.raises(Exception): + erasure_runner_tester( + privacy_request, + sample_postgres_configuration_policy, + dataset_graph, + [integration_postgres_config], + {"email": "customer-5@example.com"}, + { + "postgres_example_test_dataset:employee": [], + "postgres_example_test_dataset:visit": [], + "postgres_example_test_dataset:customer": [], + "postgres_example_test_dataset:report": [], + "postgres_example_test_dataset:orders": [], + "postgres_example_test_dataset:payment_card": [], + "postgres_example_test_dataset:service_request": [], + "postgres_example_test_dataset:login": [], + "postgres_example_test_dataset:address": [], + "postgres_example_test_dataset:order_item": [], + "postgres_example_test_dataset:product": [], + }, + db, + ) + execution_logs = db.query(ExecutionLog).filter_by( + privacy_request_id=privacy_request.id + ) + + # DSR 2.0 raises an exception on the first node hit + assert 4 == execution_logs.count() + + # Execution starts with the address collection, retries twice on failure, and then errors + assert [ + ( + CollectionAddress(log.dataset_name, log.collection_name).value, + log.status.value, + ) + for log in execution_logs.order_by("created_at") + ] == [ + ("postgres_example_test_dataset:address", "in_processing"), + ("postgres_example_test_dataset:address", "retrying"), + ("postgres_example_test_dataset:address", "retrying"), + ("postgres_example_test_dataset:address", "error"), + ] + else: + # DSR 3.0 does not raise an exception on the first node hit. 
+ # Every node has a chance to process + erasure_runner_tester( privacy_request, sample_postgres_configuration_policy, dataset_graph, [integration_postgres_config], {"email": "customer-5@example.com"}, - { - "postgres_example_test_dataset:employee": [], - "postgres_example_test_dataset:visit": [], - "postgres_example_test_dataset:customer": [], - "postgres_example_test_dataset:report": [], - "postgres_example_test_dataset:orders": [], - "postgres_example_test_dataset:payment_card": [], - "postgres_example_test_dataset:service_request": [], - "postgres_example_test_dataset:login": [], - "postgres_example_test_dataset:address": [], - "postgres_example_test_dataset:order_item": [], - "postgres_example_test_dataset:product": [], - }, + {}, db, ) + execution_logs = db.query(ExecutionLog).filter_by( + privacy_request_id=privacy_request.id, action_type=ActionType.erasure + ) + assert 40 == execution_logs.count() - execution_logs = db.query(ExecutionLog).filter_by( - privacy_request_id=privacy_request.id - ) - - assert 4 == execution_logs.count() + # These nodes were able to complete because they didn't have a PK - nothing to erase + visit_logs = execution_logs.filter_by(collection_name="visit") + assert {"in_processing", "complete"} == { + el.status.value for el in visit_logs + } - # Execution starts with the address collection, retries twice on failure, and then errors - assert [ - ( - CollectionAddress(log.dataset_name, log.collection_name).value, - log.status.value, + order_item_logs = execution_logs.filter_by(collection_name="order_item") + assert {"in_processing", "complete"} == { + el.status.value for el in order_item_logs + } + # Address log mask data couldn't run, attempted to retry twice per configuration + address_logs = execution_logs.filter_by(collection_name="address").order_by( + ExecutionLog.created_at ) - for log in execution_logs.order_by("created_at") - ] == [ - ("postgres_example_test_dataset:address", "in_processing"), - ("postgres_example_test_dataset:address", "retrying"), - ("postgres_example_test_dataset:address", "retrying"), - ("postgres_example_test_dataset:address", "error"), - ] + assert ["in_processing", "retrying", "retrying", "error"] == [ + el.status.value for el in address_logs + ] + + # Downstream request tasks were marked as error. 
Some tasks completed because there is no PK + # on their collection and we can't erase + assert {rt.status.value for rt in privacy_request.erasure_tasks} == { + "complete", + "error", + "error", + "error", + "complete", + "error", + "error", + "error", + "error", + "error", + "complete", + "error", + "error", + } @pytest.mark.integration_timescale @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_timescale_access_request_task( db, policy, timescale_connection_config, timescale_integration_db, + privacy_request, + dsr_version, + request, ) -> None: database_name = "my_timescale_db_1" - privacy_request = PrivacyRequest(id=str(uuid4())) + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 - v = await graph_task.run_access_request( + v = access_runner_tester( privacy_request, policy, integration_db_graph(database_name), @@ -1255,19 +1515,27 @@ async def test_timescale_access_request_task( @pytest.mark.integration_timescale @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_timescale_erasure_request_task( db, erasure_policy, timescale_connection_config, timescale_integration_db, + privacy_request_with_erasure_policy, + dsr_version, + request, ) -> None: + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + rule = erasure_policy.rules[0] target = rule.targets[0] target.data_category = "user" target.save(db) database_name = "my_timescale_db_1" - privacy_request = PrivacyRequest(id=str(uuid4())) dataset = integration_db_dataset(database_name, timescale_connection_config.key) @@ -1278,76 +1546,22 @@ async def test_timescale_erasure_request_task( graph = DatasetGraph(dataset) - access_results = { # To avoid running a separate access request, just feed in what the access results would be - f"{database_name}:payment_card": [ - { - "id": "pay_aaa-aaa", - "name": "Example Card 1", - "ccn": 123456789, - "customer_id": 1, - "billing_address_id": 1, - }, - { - "id": "pay_bbb-bbb", - "name": "Example Card 2", - "ccn": 987654321, - "customer_id": 2, - "billing_address_id": 1, - }, - ], - f"{database_name}:customer": [ - { - "id": 1, - "name": "John Customer", - "email": "customer-1@example.com", - "address_id": 1, - } - ], - f"{database_name}:address": [ - { - "id": 1, - "street": "Example Street", - "city": "Exampletown", - "state": "NY", - "zip": "12345", - }, - { - "id": 2, - "street": "Example Lane", - "city": "Exampletown", - "state": "NY", - "zip": "12321", - }, - ], - f"{database_name}:orders": [ - { - "id": "ord_aaa-aaa", - "customer_id": 1, - "shipping_address_id": 2, - "payment_card_id": "pay_aaa-aaa", - }, - { - "id": "ord_ccc-ccc", - "customer_id": 1, - "shipping_address_id": 1, - "payment_card_id": "pay_aaa-aaa", - }, - { - "id": "ord_ddd-ddd", - "customer_id": 1, - "shipping_address_id": 1, - "payment_card_id": "pay_bbb-bbb", - }, - ], - } + v = access_runner_tester( + privacy_request_with_erasure_policy, + erasure_policy, + graph, + [timescale_connection_config], + {"email": "customer-1@example.com"}, + db, + ) - v = await graph_task.run_erasure( - privacy_request, + v = erasure_runner_tester( + privacy_request_with_erasure_policy, erasure_policy, graph, [timescale_connection_config], {"email": "customer-1@example.com"}, - access_results, + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) assert v == { @@ -1388,11 +1602,22 @@ async def 
test_timescale_erasure_request_task( @pytest.mark.integration_timescale @pytest.mark.integration @pytest.mark.asyncio +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) async def test_timescale_query_and_mask_hypertable( - db, policy, erasure_policy, timescale_connection_config, timescale_integration_db + db, + erasure_policy, + timescale_connection_config, + timescale_integration_db, + privacy_request_with_erasure_policy, + dsr_version, + request, ) -> None: + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + database_name = "my_timescale_db_1" - privacy_request = PrivacyRequest(id=str(uuid4())) dataset = integration_db_dataset(database_name, timescale_connection_config.key) # For this test, add a new collection to our standard dataset corresponding to the @@ -1411,10 +1636,18 @@ async def test_timescale_query_and_mask_hypertable( dataset.collections.append(onsite_personnel_collection) graph = DatasetGraph(dataset) + rule = erasure_policy.rules[0] + target = rule.targets[0] + target.data_category = "user" + target.save(db) + # Update data category on responsible field + field( + [dataset], database_name, "onsite_personnel", "responsible" + ).data_categories = ["user.contact.email"] - access_results = await graph_task.run_access_request( - privacy_request, - policy, + access_results = access_runner_tester( + privacy_request_with_erasure_policy, + erasure_policy, graph, [timescale_connection_config], {"email": "employee-1@example.com"}, @@ -1429,23 +1662,14 @@ async def test_timescale_query_and_mask_hypertable( {"responsible": "employee-1@example.com", "time": datetime(2022, 1, 5, 9, 0)}, ] - rule = erasure_policy.rules[0] - target = rule.targets[0] - target.data_category = "user" - target.save(db) - # Update data category on responsible field - field( - [dataset], database_name, "onsite_personnel", "responsible" - ).data_categories = ["user.contact.email"] - # Run an erasure on the hypertable targeting the responsible field - v = await graph_task.run_erasure( - privacy_request, + v = erasure_runner_tester( + privacy_request_with_erasure_policy, erasure_policy, graph, [timescale_connection_config], {"email": "employee-1@example.com"}, - get_cached_data_for_erasures(privacy_request.id), + get_cached_data_for_erasures(privacy_request_with_erasure_policy.id), db, ) diff --git a/tests/ops/models/test_privacy_request.py b/tests/ops/models/test_privacy_request.py index 9ce9747ce7..eed9decff2 100644 --- a/tests/ops/models/test_privacy_request.py +++ b/tests/ops/models/test_privacy_request.py @@ -19,6 +19,7 @@ from fides.api.models.privacy_request import ( CheckpointActionRequired, ConsentRequest, + ManualAction, PrivacyRequest, PrivacyRequestError, PrivacyRequestNotifications, @@ -26,10 +27,14 @@ ProvidedIdentity, can_run_checkpoint, ) +from fides.api.schemas.policy import ActionType from fides.api.schemas.privacy_request import CustomPrivacyRequestField from fides.api.schemas.redis_cache import Identity, LabeledIdentity -from fides.api.service.connectors.manual_connector import ManualAction -from fides.api.util.cache import FidesopsRedis, get_identity_cache_key +from fides.api.util.cache import ( + FidesopsRedis, + cache_task_tracking_key, + get_identity_cache_key, +) from fides.api.util.constants import API_DATE_FORMAT from fides.config import CONFIG @@ -488,65 +493,16 @@ def test_privacy_request_unpause(self, privacy_request): assert privacy_request.get_paused_collection_details() is None -class TestCacheManualInput: - def 
test_cache_manual_access_input(self, privacy_request): - manual_data = [{"id": 1, "name": "Jane"}, {"id": 2, "name": "Hank"}] - - privacy_request.cache_manual_access_input(paused_location, manual_data) - assert ( - privacy_request.get_manual_access_input( - paused_location, - ) - == manual_data - ) - - def test_cache_empty_manual_input(self, privacy_request): - manual_data = [] - privacy_request.cache_manual_access_input(paused_location, manual_data) - - assert ( - privacy_request.get_manual_access_input( - paused_location, - ) - == [] - ) - - def test_no_manual_data_in_cache(self, privacy_request): - assert ( - privacy_request.get_manual_access_input( - paused_location, - ) - is None - ) - - -class TestCacheManualErasureCount: - def test_cache_manual_erasure_count(self, privacy_request): - privacy_request.cache_manual_erasure_count(paused_location, 5) - - cached_data = privacy_request.get_manual_erasure_count(paused_location) - assert cached_data == 5 - - def test_no_erasure_data_cached(self, privacy_request): - cached_data = privacy_request.get_manual_erasure_count(paused_location) - assert cached_data is None - - def test_zero_cached(self, privacy_request): - privacy_request.cache_manual_erasure_count(paused_location, 0) - cached_data = privacy_request.get_manual_erasure_count(paused_location) - assert cached_data == 0 - - class TestPrivacyRequestCacheFailedStep: - def test_cache_failed_step_and_collection(self, privacy_request): - privacy_request.cache_failed_checkpoint_details( - step=CurrentStep.erasure, collection=paused_location - ) + def test_cache_failed_step(self, privacy_request): + privacy_request.cache_failed_checkpoint_details(step=CurrentStep.erasure) cached_data = privacy_request.get_failed_checkpoint_details() assert cached_data.step == CurrentStep.erasure - assert cached_data.collection == paused_location - assert cached_data.action_needed is None + assert cached_data.collection is None # This is deprecated + assert ( + cached_data.action_needed is None + ) # This isn't applicable for failed details def test_cache_null_step_and_location(self, privacy_request): privacy_request.cache_failed_checkpoint_details() @@ -1228,3 +1184,19 @@ def test_persist_custom_identities(self, db, privacy_request): customer_id=LabeledIdentity(label="Custom ID", value=123), account_id=LabeledIdentity(label="Account ID", value="456"), ) + + +class TestGetCeleryTaskRequestTaskIds: + def test_get_celery_task_request_task_ids(self, privacy_request, request_task): + """Not all request tasks have celery task ids in this test -""" + + assert privacy_request.get_request_task_celery_task_ids() == [] + + cache_task_tracking_key(request_task.id, "test_celery_task_key") + root_task = privacy_request.get_root_task_by_action(ActionType.access) + cache_task_tracking_key(root_task.id, "test_root_task_celery_key") + + assert set(privacy_request.get_request_task_celery_task_ids()) == { + "test_celery_task_key", + "test_root_task_celery_key", + } diff --git a/tests/ops/models/test_request_task.py b/tests/ops/models/test_request_task.py new file mode 100644 index 0000000000..14073678a7 --- /dev/null +++ b/tests/ops/models/test_request_task.py @@ -0,0 +1,328 @@ +import json +from unittest import mock + +import pytest + +from fides.api.graph.config import ( + ROOT_COLLECTION_ADDRESS, + TERMINATOR_ADDRESS, + CollectionAddress, +) +from fides.api.models.privacy_request import ExecutionLogStatus, RequestTask +from fides.api.schemas.policy import ActionType +from fides.api.util.cache import ( + CustomJSONEncoder, + 
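# Editor's note: cache_task_tracking_key (imported here) records the Celery
# task id that is executing a given RequestTask, keyed by the task's database
# id. That mapping is what get_request_task_celery_task_ids() reads back, and
# what request_task_running() checks against Celery's inspect API further
# down. As the tests use it:
#
#   cache_task_tracking_key(request_task.id, "test_celery_task_key")
#   privacy_request.get_request_task_celery_task_ids()
#   # -> ["test_celery_task_key"] once at least one task id is cached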
FidesopsRedis, + cache_task_tracking_key, + get_cache, +) + + +class TestRequestTask: + def test_basic_attributes( + self, db, request_task, privacy_request, erasure_request_task + ): + assert privacy_request.request_tasks.count() == 6 + assert privacy_request.access_tasks.count() == 3 + assert privacy_request.erasure_tasks.count() == 3 + assert privacy_request.consent_tasks.count() == 0 + + assert erasure_request_task.privacy_request_id == privacy_request.id + assert request_task.privacy_request_id == privacy_request.id + assert request_task.privacy_request == privacy_request + assert request_task.action_type == ActionType.access + + assert request_task.get_tasks_with_same_action_type( + db, "__ROOT__:__ROOT__" + ).all() == [privacy_request.get_root_task_by_action(ActionType.access)] + assert erasure_request_task.get_tasks_with_same_action_type( + db, "__ROOT__:__ROOT__" + ).all() == [privacy_request.get_root_task_by_action(ActionType.erasure)] + + @pytest.mark.usefixtures("request_task", "erasure_request_task") + def test_get_tasks_by_action(self, privacy_request): + assert ( + privacy_request.get_tasks_by_action(ActionType.access).all() + == privacy_request.access_tasks.all() + ) + assert ( + privacy_request.get_tasks_by_action(ActionType.erasure).all() + == privacy_request.erasure_tasks.all() + ) + assert privacy_request.get_tasks_by_action(ActionType.consent).all() == [] + + with pytest.raises(Exception): + assert privacy_request.get_tasks_by_action(ActionType.update).all() == [] + + @pytest.mark.usefixtures("request_task", "erasure_request_task") + def test_get_root_task_by_action(self, privacy_request): + task = privacy_request.get_root_task_by_action(ActionType.access) + assert task.collection_address == "__ROOT__:__ROOT__" + assert task.action_type == ActionType.access + assert task.is_root_task + assert not task.is_terminator_task + + erasure_task = privacy_request.get_root_task_by_action(ActionType.erasure) + assert erasure_task.collection_address == "__ROOT__:__ROOT__" + assert erasure_task.action_type == ActionType.erasure + + with pytest.raises(Exception): + privacy_request.get_root_task_by_action(ActionType.consent) + + @pytest.mark.usefixtures("request_task", "erasure_request_task") + def test_get_terminate_task_by_action(self, privacy_request): + task = privacy_request.get_terminate_task_by_action(ActionType.access) + assert task.collection_address == "__TERMINATE__:__TERMINATE__" + assert task.action_type == ActionType.access + assert not task.is_root_task + assert task.is_terminator_task + + erasure_task = privacy_request.get_terminate_task_by_action(ActionType.erasure) + assert erasure_task.collection_address == "__TERMINATE__:__TERMINATE__" + assert erasure_task.action_type == ActionType.erasure + + with pytest.raises(Exception): + privacy_request.get_terminate_task_by_action(ActionType.consent) + + def test_request_task_address(self, request_task): + assert request_task.request_task_address == CollectionAddress( + "test_dataset", "test_collection" + ) + + def test_get_existing_request_task( + self, db, privacy_request, request_task, erasure_request_task + ): + assert ( + privacy_request.get_existing_request_task( + db, ActionType.access, request_task.request_task_address + ) + == request_task + ) + assert ( + privacy_request.get_existing_request_task( + db, ActionType.erasure, erasure_request_task.request_task_address + ) + == erasure_request_task + ) + + assert ( + privacy_request.get_existing_request_task( + db, ActionType.consent, 
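The `request_task` and `erasure_request_task` fixtures these assertions rely on evidently build a three-task chain per action type (root, one collection task, terminator), which is why six tasks exist across access and erasure. A sketch of constructing one such chain, modeled on the `consent_request_task` fixture defined later in this same file; the real fixture bodies may differ.

```python
from fides.api.models.privacy_request import RequestTask
from fides.api.schemas.policy import ActionType

ROOT = "__ROOT__:__ROOT__"
TERMINATOR = "__TERMINATE__:__TERMINATE__"


def build_access_chain(db, privacy_request):
    """Hypothetical helper: root task (already complete) -> collection
    task -> terminator task, wired together by address strings."""
    common = {
        "action_type": ActionType.access,
        "privacy_request_id": privacy_request.id,
    }
    root = RequestTask.create(db, data={
        **common,
        "status": "complete",
        "collection_address": ROOT,
        "dataset_name": "__ROOT__",
        "collection_name": "__ROOT__",
        "upstream_tasks": [],
        "downstream_tasks": ["test_dataset:test_collection"],
    })
    collection_task = RequestTask.create(db, data={
        **common,
        "status": "pending",
        "collection_address": "test_dataset:test_collection",
        "dataset_name": "test_dataset",
        "collection_name": "test_collection",
        "upstream_tasks": [ROOT],
        "downstream_tasks": [TERMINATOR],
    })
    terminator = RequestTask.create(db, data={
        **common,
        "status": "pending",
        "collection_address": TERMINATOR,
        "dataset_name": "__TERMINATE__",
        "collection_name": "__TERMINATE__",
        "upstream_tasks": ["test_dataset:test_collection"],
        "downstream_tasks": [],
    })
    return root, collection_task, terminator
```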
erasure_request_task.request_task_address + ) + is None + ) + + def test_get_pending_downstream_tasks(self, db, request_task): + root_task = request_task.get_tasks_with_same_action_type( + db, ROOT_COLLECTION_ADDRESS.value + ).first() + terminator_task = request_task.get_tasks_with_same_action_type( + db, TERMINATOR_ADDRESS.value + ).first() + + assert root_task.get_pending_downstream_tasks(db).all() == [request_task] + assert request_task.get_pending_downstream_tasks(db).all() == [terminator_task] + assert terminator_task.get_pending_downstream_tasks(db).all() == [] + + @mock.patch("fides.api.util.cache.celery_app.control.inspect.query_task") + def test_request_task_running(self, query_task_mock, db, request_task): + assert request_task.request_task_running() is False + + cache_task_tracking_key(request_task.id, "test_5678") + + assert request_task.request_task_running() is False + + query_task_mock.return_value = {"@celery1234": {}} + + assert request_task.request_task_running() is False + assert request_task.can_queue_request_task(db) is True + + query_task_mock.return_value = {"@celery1234": {"test_5678": ["reserved", {}]}} + + assert request_task.request_task_running() is True + assert request_task.can_queue_request_task(db) is False + + def test_upstream_tasks_complete(self, db, request_task): + # The Request Task only has the Root Task upstream, which is complete + assert request_task.upstream_tasks == [ROOT_COLLECTION_ADDRESS.value] + assert request_task.upstream_tasks_complete(db) + + # The root Task has nothing upstream + root_task = request_task.get_tasks_with_same_action_type( + db, ROOT_COLLECTION_ADDRESS.value + ).first() + assert root_task.status == ExecutionLogStatus.complete + assert root_task.is_root_task + assert root_task.upstream_tasks_complete(db) + + # The Terminator task has the request task upstream which is pending + terminator_task = request_task.get_tasks_with_same_action_type( + db, TERMINATOR_ADDRESS.value + ).first() + assert terminator_task.is_terminator_task + assert terminator_task.upstream_tasks == [request_task.collection_address] + assert request_task.status == ExecutionLogStatus.pending + assert not terminator_task.upstream_tasks_complete(db) + assert not terminator_task.can_queue_request_task(db) + + # Set the request task to be skipped + request_task.update_status(db, ExecutionLogStatus.skipped) + # Skipped is considered to be a completed state + assert terminator_task.upstream_tasks_complete(db) + assert terminator_task.can_queue_request_task(db) + + def test_update_status(self, db, request_task): + assert request_task.status == ExecutionLogStatus.pending + request_task.update_status(db, ExecutionLogStatus.complete) + + assert request_task.status == ExecutionLogStatus.complete + + def test_save_filtered_access_results(self, db, privacy_request): + assert privacy_request.get_filtered_access_results() == {} + + privacy_request.save_filtered_access_results( + db, + results={ + "policy_rule_key": { + "test_dataset:test_collection": [ + {"name": "Jane", "address": "101 Test Town"} + ], + "test_dataset:test_collection_2": [ + {"id": 100, "email": "jane@example.com"} + ], + } + }, + ) + + assert privacy_request.get_filtered_access_results() == { + "policy_rule_key": { + "test_dataset:test_collection": [ + {"name": "Jane", "address": "101 Test Town"} + ], + "test_dataset:test_collection_2": [ + {"id": 100, "email": "jane@example.com"} + ], + } + } + + +class TestGetRawAccessResults: + def test_no_results(self, privacy_request): + assert 
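`test_request_task_running` mocks Celery's inspect API: `query_task` answers with `{worker_name: {task_id: [state, info]}}`, and a task only counts as running when its cached Celery id appears in that mapping, which in turn blocks `can_queue_request_task`. A rough re-implementation of the check under those assumptions; the real logic lives on `RequestTask` and may invoke the inspect API differently.

```python
from fides.api.tasks import celery_app  # assumed location of the Celery app


def request_task_running_sketch(celery_task_id: str) -> bool:
    """Hypothetical re-implementation of RequestTask.request_task_running."""
    if not celery_task_id:
        # No tracking key was ever cached for this task.
        return False
    # query_task returns {worker_name: {task_id: [state, info]}} for known ids.
    queried = celery_app.control.inspect().query_task(celery_task_id) or {}
    return any(
        celery_task_id in tasks_by_id for tasks_by_id in queried.values()
    )
```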
privacy_request.get_raw_access_results() == {} + + @pytest.mark.usefixtures("request_task") + def test_request_tasks_incomplete(self, privacy_request): + assert privacy_request.get_raw_access_results() == {} + + def test_request_tasks_complete_dsr_3_0(self, db, privacy_request, request_task): + """DSR 3.0 stores results on RequestTask.access_data""" + assert request_task.get_decoded_access_data() == [] + + request_task.access_data = json.dumps( + [{"name": "Jane", "street": "102 Test Town"}], cls=CustomJSONEncoder + ) + request_task.update_status(db, ExecutionLogStatus.complete) + assert request_task.get_decoded_access_data() == [ + {"name": "Jane", "street": "102 Test Town"} + ] + + assert privacy_request.get_raw_access_results() == { + "test_dataset:test_collection": [ + {"name": "Jane", "street": "102 Test Town"} + ] + } + + def test_dsr_2_0(self, privacy_request): + """DSR 2.0 uses the cache to store results""" + cache: FidesopsRedis = get_cache() + key = f"access_request__test_dataset:test_collection" + cache.set_encoded_object( + f"{privacy_request.id}__{key}", + [{"name": "Jane", "street": "102 Test Town"}], + ) + + assert privacy_request.get_raw_access_results() == { + "test_dataset:test_collection": [ + {"name": "Jane", "street": "102 Test Town"} + ] + } + + +class TestGetRawMaskingCounts: + def test_no_results(self, privacy_request): + assert privacy_request.get_raw_masking_counts() == {} + + @pytest.mark.usefixtures("erasure_request_task") + def test_request_tasks_incomplete(self, privacy_request): + assert privacy_request.get_raw_masking_counts() == {} + + def test_request_tasks_complete_dsr_3_0( + self, db, privacy_request, erasure_request_task + ): + """DSR 3.0 stores results on RequestTask.rows_masked""" + erasure_request_task.rows_masked = 2 + erasure_request_task.update_status(db, ExecutionLogStatus.complete) + assert privacy_request.get_raw_masking_counts() == { + "test_dataset:test_collection": 2 + } + + def test_dsr_2_0(self, privacy_request): + """DSR 2.0 uses the cache to store rows_masked""" + cache: FidesopsRedis = get_cache() + key = f"erasure_request__test_dataset:test_collection" + cache.set_encoded_object(f"{privacy_request.id}__{key}", 2) + + assert privacy_request.get_raw_masking_counts() == { + "test_dataset:test_collection": 2 + } + + +class TestGetConsentResults: + @pytest.fixture(scope="function") + def consent_request_task(self, db, privacy_request): + request_task = RequestTask.create( + db, + data={ + "action_type": ActionType.consent, + "status": "pending", + "privacy_request_id": privacy_request.id, + "collection_address": "test_dataset:test_collection", + "dataset_name": "test_dataset", + "collection_name": "test_collection", + "upstream_tasks": ["__ROOT__:__ROOT__"], + "downstream_tasks": ["__TERMINATE__:__TERMINATE__"], + }, + ) + yield request_task + request_task.delete(db) + + def test_no_results(self, privacy_request): + assert privacy_request.get_consent_results() == {} + + def test_request_tasks_incomplete(self, consent_request_task, privacy_request): + assert privacy_request.get_consent_results() == {} + + def test_request_tasks_complete_dsr_3_0( + self, db, privacy_request, consent_request_task + ): + """DSR 3.0 stores results on RequestTask.rows_masked""" + consent_request_task.consent_sent = True + consent_request_task.update_status(db, ExecutionLogStatus.complete) + assert privacy_request.get_consent_results() == { + "test_dataset:test_collection": True + } + + +class TestGetDecodedDataForErasures: + def test_no_data(self, request_task): 
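The access-result tests above pin down the two storage paths side by side: DSR 3.0 serializes rows onto the encrypted `RequestTask.access_data` column, while DSR 2.0 writes them to Redis under a per-request key, and `get_raw_access_results()` normalizes both into `{collection_address: rows}`. A condensed sketch of the two write paths exactly as the tests exercise them:

```python
import json

from fides.api.models.privacy_request import ExecutionLogStatus
from fides.api.util.cache import CustomJSONEncoder, get_cache

rows = [{"name": "Jane", "street": "102 Test Town"}]


def store_dsr_3_0(db, request_task):
    # DSR 3.0: persist the rows on the task itself, then mark it complete.
    request_task.access_data = json.dumps(rows, cls=CustomJSONEncoder)
    request_task.update_status(db, ExecutionLogStatus.complete)


def store_dsr_2_0(privacy_request):
    # DSR 2.0: encode the rows into Redis under the legacy key format.
    cache = get_cache()
    key = "access_request__test_dataset:test_collection"
    cache.set_encoded_object(f"{privacy_request.id}__{key}", rows)
```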
+ assert request_task.get_decoded_data_for_erasures() == [] + + def test_request_task_has_erasure_data(self, db, request_task): + request_task.data_for_erasures = json.dumps( + [{"id": 1, "name": "Jane"}], cls=CustomJSONEncoder + ) + request_task.save(db) + + assert request_task.get_decoded_data_for_erasures() == [ + {"id": 1, "name": "Jane"} + ] diff --git a/tests/ops/service/connectors/test_fides_connector.py b/tests/ops/service/connectors/test_fides_connector.py index 9badd17113..4c46d5db9b 100644 --- a/tests/ops/service/connectors/test_fides_connector.py +++ b/tests/ops/service/connectors/test_fides_connector.py @@ -8,7 +8,11 @@ from fides.api.models.connectionconfig import ConnectionConfig, ConnectionTestStatus from fides.api.models.datasetconfig import DatasetConfig from fides.api.models.policy import Policy -from fides.api.models.privacy_request import PrivacyRequest, PrivacyRequestStatus +from fides.api.models.privacy_request import ( + PrivacyRequest, + PrivacyRequestStatus, + RequestTask, +) from fides.api.schemas.policy import ActionType from fides.api.service.connectors.fides.fides_client import FidesClient from fides.api.service.connectors.fides_connector import ( @@ -137,6 +141,7 @@ def test_retrieve_data( node=node, policy=policy_local_storage, privacy_request=privacy_request, + request_task=RequestTask(), input_data=[], ) diff --git a/tests/ops/service/connectors/test_queryconfig.py b/tests/ops/service/connectors/test_queryconfig.py index a73de222b1..8ff2fd99ae 100644 --- a/tests/ops/service/connectors/test_queryconfig.py +++ b/tests/ops/service/connectors/test_queryconfig.py @@ -12,6 +12,7 @@ ObjectField, ScalarField, ) +from fides.api.graph.execution import ExecutionNode from fides.api.graph.graph import DatasetGraph, Edge from fides.api.graph.traversal import Traversal, TraversalNode from fides.api.models.datasetconfig import convert_dataset_to_graph @@ -39,16 +40,23 @@ graph: DatasetGraph = integration_db_graph("postgres_example") traversal = Traversal(graph, {"email": "X"}) traversal_nodes: Dict[CollectionAddress, TraversalNode] = traversal.traversal_node_dict -payment_card_node = traversal_nodes[ +payment_card_traversal_node = traversal_nodes[ CollectionAddress("postgres_example", "payment_card") ] -user_node = traversal_nodes[CollectionAddress("postgres_example", "payment_card")] +payment_card_request_task = payment_card_traversal_node.to_mock_request_task() +payment_card_node: ExecutionNode = ExecutionNode(payment_card_request_task) + +user_traversal_node = traversal_nodes[ + CollectionAddress("postgres_example", "payment_card") +] +user_request_task = user_traversal_node.to_mock_request_task() +user_node = ExecutionNode(user_request_task) privacy_request = PrivacyRequest(id="234544") class TestSQLQueryConfig: def test_extract_query_components(self): - def found_query_keys(node: TraversalNode, values: Dict[str, Any]) -> Set[str]: + def found_query_keys(node: ExecutionNode, values: Dict[str, Any]) -> Set[str]: return set(node.typed_filtered_values(values).keys()) config = SQLQueryConfig(payment_card_node) @@ -177,7 +185,7 @@ def test_update_rule_target_fields( customer_node = traversal.traversal_node_dict[ CollectionAddress("postgres_example_test_dataset", "customer") - ] + ].to_mock_execution_node() rule = erasure_policy.rules[0] config = SQLQueryConfig(customer_node) @@ -195,7 +203,7 @@ def test_update_rule_target_fields( # Check different collection address_node = traversal.traversal_node_dict[ CollectionAddress("postgres_example_test_dataset", "address") - ] + 
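From here on, the query-config tests stop operating on `TraversalNode` directly: DSR 3.0 executes against an `ExecutionNode` hydrated from a (possibly mock) `RequestTask`. A sketch of the two equivalent conversions used throughout the diff; the `integration_db_graph` import path is an assumption.

```python
from fides.api.graph.config import CollectionAddress
from fides.api.graph.execution import ExecutionNode
from fides.api.graph.traversal import Traversal

from tests.ops.graph.graph_test_util import integration_db_graph  # assumed path

graph = integration_db_graph("postgres_example")
traversal = Traversal(graph, {"email": "X"})
traversal_node = traversal.traversal_node_dict[
    CollectionAddress("postgres_example", "payment_card")
]

# Long form: serialize the traversal node into a mock RequestTask, then wrap it.
request_task = traversal_node.to_mock_request_task()
execution_node = ExecutionNode(request_task)

# Short form used by most of the updated tests: one-step convenience helper.
execution_node = traversal_node.to_mock_execution_node()
```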
].to_mock_execution_node() config = SQLQueryConfig(address_node) assert config.build_rule_target_field_paths(erasure_policy) == { rule: [FieldPath(x) for x in ["city", "house", "street", "state", "zip"]] @@ -211,7 +219,7 @@ def test_generate_update_stmt_one_field( customer_node = traversal.traversal_node_dict[ CollectionAddress("postgres_example_test_dataset", "customer") - ] + ].to_mock_execution_node() config = SQLQueryConfig(customer_node) row = { @@ -238,7 +246,7 @@ def test_generate_update_stmt_length_truncation( customer_node = traversal.traversal_node_dict[ CollectionAddress("postgres_example_test_dataset", "customer") - ] + ].to_mock_execution_node() config = SQLQueryConfig(customer_node) row = { @@ -269,7 +277,7 @@ def test_generate_update_stmt_multiple_fields_same_rule( customer_node = traversal.traversal_node_dict[ CollectionAddress("postgres_example_test_dataset", "customer") - ] + ].to_mock_execution_node() config = SQLQueryConfig(customer_node) row = { @@ -334,7 +342,7 @@ def test_generate_update_stmts_from_multiple_rules( customer_node = traversal.traversal_node_dict[ CollectionAddress("postgres_example_test_dataset", "customer") - ] + ].to_mock_execution_node() config = SQLQueryConfig(customer_node) @@ -370,13 +378,13 @@ def combined_traversal(self, connection_config, integration_mongodb_config): def customer_details_node(self, combined_traversal): return combined_traversal.traversal_node_dict[ CollectionAddress("mongo_test", "customer_details") - ] + ].to_mock_execution_node() @pytest.fixture(scope="function") def customer_feedback_node(self, combined_traversal): return combined_traversal.traversal_node_dict[ CollectionAddress("mongo_test", "customer_feedback") - ] + ].to_mock_execution_node() def test_field_map_nested(self, customer_details_node): config = MongoQueryConfig(customer_details_node) @@ -442,7 +450,7 @@ def test_generate_query( # Test query on nested field customer_feedback = traversal.traversal_node_dict[ CollectionAddress("mongo_test", "customer_feedback") - ] + ].to_mock_execution_node() config = MongoQueryConfig(customer_feedback) input_data = {"customer_information.email": ["customer-1@example.com"]} # Tuple of query, projection - Searching for documents with nested @@ -455,7 +463,7 @@ def test_generate_query( # Test query nested data customer_details = traversal.traversal_node_dict[ CollectionAddress("mongo_test", "customer_details") - ] + ].to_mock_execution_node() config = MongoQueryConfig(customer_details) input_data = {"customer_id": [1]} # Tuple of query, projection - Projection is specifying fields at the top-level. 
Nested data will @@ -493,7 +501,7 @@ def test_generate_update_stmt_multiple_fields( traversal = Traversal(dataset_graph, {"email": "customer-1@example.com"}) customer_details = traversal.traversal_node_dict[ CollectionAddress("mongo_test", "customer_details") - ] + ].to_mock_execution_node() config = MongoQueryConfig(customer_details) row = { "birthday": "1988-01-10", @@ -557,7 +565,7 @@ def test_generate_update_stmt_multiple_rules( customer_details = traversal.traversal_node_dict[ CollectionAddress("mongo_test", "customer_details") - ] + ].to_mock_execution_node() config = MongoQueryConfig(customer_details) row = { @@ -633,13 +641,13 @@ def traversal(self, identity, dataset_graph): def customer_node(self, traversal): return traversal.traversal_node_dict[ CollectionAddress("dynamodb_example_test_dataset", "customer") - ] + ].to_mock_execution_node() @pytest.fixture(scope="function") def customer_identifier_node(self, traversal): return traversal.traversal_node_dict[ CollectionAddress("dynamodb_example_test_dataset", "customer_identifier") - ] + ].to_mock_execution_node() @pytest.fixture(scope="function") def customer_row(self): diff --git a/tests/ops/service/connectors/test_saas_connector.py b/tests/ops/service/connectors/test_saas_connector.py index 52f656ba62..4c0c85cc5a 100644 --- a/tests/ops/service/connectors/test_saas_connector.py +++ b/tests/ops/service/connectors/test_saas_connector.py @@ -10,6 +10,7 @@ from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND from fides.api.common_exceptions import SkippingConsentPropagation +from fides.api.graph.execution import ExecutionNode from fides.api.graph.graph import Node from fides.api.graph.traversal import TraversalNode from fides.api.models.policy import Policy @@ -146,9 +147,11 @@ def test_delete_only_endpoint( ), ) traversal_node = TraversalNode(node) + request_task = traversal_node.to_mock_request_task() + execution_node = ExecutionNode(request_task) connector: SaaSConnector = get_connector(saas_example_connection_config) assert connector.retrieve_data( - traversal_node, Policy(), PrivacyRequest(id="123"), {} + execution_node, Policy(), PrivacyRequest(id="123"), request_task, {} ) == [{}] @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") @@ -176,6 +179,9 @@ def test_input_values( ), ) traversal_node = TraversalNode(node) + request_task = traversal_node.to_mock_request_task() + execution_node = ExecutionNode(request_task) + connector: SaaSConnector = get_connector(saas_example_connection_config) # this request requires the email identity in the filter postprocessor so we include it here @@ -183,9 +189,10 @@ def test_input_values( privacy_request.cache_identity(Identity(email="test@example.com")) assert connector.retrieve_data( - traversal_node, + execution_node, Policy(), privacy_request, + request_task, {"fidesops_grouped_inputs": [], "conversation_id": ["456"]}, ) == [{"id": "123", "from_email": "test@example.com"}] @@ -208,10 +215,12 @@ def test_missing_input_values( ), ) traversal_node = TraversalNode(node) + request_task = traversal_node.to_mock_request_task() + execution_node = ExecutionNode(request_task) connector: SaaSConnector = get_connector(saas_example_connection_config) assert ( connector.retrieve_data( - traversal_node, Policy(), PrivacyRequest(id="123"), {} + execution_node, Policy(), PrivacyRequest(id="123"), request_task, {} ) == [] ) @@ -239,11 +248,14 @@ def test_grouped_input_values( ), ) traversal_node = TraversalNode(node) + request_task = 
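The SaaS connector hunks that follow show the new `retrieve_data` call shape: an `ExecutionNode` replaces the traversal node, and the backing `RequestTask` is threaded through just before the input data. A condensed usage sketch; the `get_connector` import path is an assumption.

```python
from fides.api.graph.execution import ExecutionNode
from fides.api.graph.traversal import TraversalNode
from fides.api.models.policy import Policy
from fides.api.models.privacy_request import PrivacyRequest
from fides.api.service.connectors import get_connector  # assumed import path


def retrieve_rows(saas_example_connection_config, node):
    traversal_node = TraversalNode(node)
    request_task = traversal_node.to_mock_request_task()
    execution_node = ExecutionNode(request_task)
    connector = get_connector(saas_example_connection_config)
    # New signature: (node, policy, privacy_request, request_task, input_data)
    return connector.retrieve_data(
        execution_node, Policy(), PrivacyRequest(id="123"), request_task, {}
    )
```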
traversal_node.to_mock_request_task() + execution_node = ExecutionNode(request_task) connector: SaaSConnector = get_connector(saas_example_connection_config) assert connector.retrieve_data( - traversal_node, + execution_node, Policy(), PrivacyRequest(id="123"), + request_task, { "fidesops_grouped_inputs": [ { @@ -273,10 +285,12 @@ def test_missing_grouped_inputs_input_values( ), ) traversal_node = TraversalNode(node) + request_task = traversal_node.to_mock_request_task() + execution_node = ExecutionNode(request_task) connector: SaaSConnector = get_connector(saas_example_connection_config) assert ( connector.retrieve_data( - traversal_node, Policy(), PrivacyRequest(id="123"), {} + execution_node, Policy(), PrivacyRequest(id="123"), request_task, {} ) == [] ) @@ -306,16 +320,18 @@ def test_skip_missing_param_values_masking( ) traversal_node = TraversalNode(node) + request_task = traversal_node.to_mock_request_task() + execution_node = ExecutionNode(request_task) connector: SaaSConnector = get_connector(saas_example_connection_config) # Base case - we can populate all placeholders in request body assert ( connector.mask_data( - traversal_node, + execution_node, Policy(), PrivacyRequest(id="123"), - {"customer_id": 1}, - {"phone_number": "555-555-5555"}, + request_task, + [{"customer_id": 1}], ) == 1 ) @@ -330,11 +346,11 @@ def test_skip_missing_param_values_masking( # Should raise ValueError because we don't have email value for request body with pytest.raises(ValueError): connector.mask_data( - traversal_node, + execution_node, Policy(), PrivacyRequest(id="123"), - {"customer_id": 1}, - {"phone_number": "555-555-5555"}, + request_task, + [{"customer_id": 1}], ) # Set skip_missing_param_values to True, so the missing placeholder just causes the request to be skipped @@ -343,11 +359,11 @@ def test_skip_missing_param_values_masking( ].requests.update.skip_missing_param_values = True assert ( connector.mask_data( - traversal_node, + execution_node, Policy(), PrivacyRequest(id="123"), - {"customer_id": 1}, - {"phone_number": "555-555-5555"}, + request_task, + [{"customer_id": 1}], ) == 0 ) @@ -430,11 +446,15 @@ def test_no_preferences_to_propagate( ): connector = get_connector(mailchimp_transactional_connection_config_no_secrets) with pytest.raises(SkippingConsentPropagation) as exc: + traversal_node = TraversalNode(generate_node("a", "b", "c", "c2")) + request_task = traversal_node.to_mock_request_task() + execution_node = traversal_node.to_mock_execution_node() connector.run_consent_request( - node=TraversalNode(generate_node("a", "b", "c", "c2")), + node=execution_node, policy=consent_policy, privacy_request=privacy_request_with_consent_policy, identity_data={"ljt_readerID": "abcde"}, + request_task=request_task, session=db, ) assert "no actionable consent preferences to propagate" in str(exc) @@ -457,11 +477,15 @@ def test_data_use_mismatch( connector = get_connector(mailchimp_transactional_connection_config_no_secrets) with pytest.raises(SkippingConsentPropagation) as exc: + traversal_node = TraversalNode(generate_node("a", "b", "c", "c2")) + request_task = traversal_node.to_mock_request_task() + execution_node = traversal_node.to_mock_execution_node() connector.run_consent_request( - node=TraversalNode(generate_node("a", "b", "c", "c2")), + node=execution_node, policy=consent_policy, privacy_request=privacy_request_with_consent_policy, identity_data={"ljt_readerID": "abcde"}, + request_task=request_task, session=db, ) @@ -488,11 +512,15 @@ def test_enforcement_level_not_system_wide( 
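`mask_data` changed shape as well: it now receives the `RequestTask` plus the retrieved rows as a list (instead of a single row and separate input values) and returns the number of rows masked, with 0 indicating the update was skipped. A condensed sketch of the new call, under the same import assumptions as above:

```python
from fides.api.graph.execution import ExecutionNode
from fides.api.models.policy import Policy
from fides.api.models.privacy_request import PrivacyRequest
from fides.api.service.connectors import get_connector  # assumed import path


def mask_customer(saas_example_connection_config, traversal_node):
    request_task = traversal_node.to_mock_request_task()
    execution_node = ExecutionNode(request_task)
    connector = get_connector(saas_example_connection_config)
    # mask_data now takes the rows retrieved for erasure as a list and
    # returns how many of them were masked (0 when the update is skipped).
    return connector.mask_data(
        execution_node,
        Policy(),
        PrivacyRequest(id="123"),
        request_task,
        [{"customer_id": 1}],
    )
```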
connector = get_connector(mailchimp_transactional_connection_config_no_secrets) with pytest.raises(SkippingConsentPropagation) as exc: + traversal_node = TraversalNode(generate_node("a", "b", "c", "c2")) + request_task = traversal_node.to_mock_request_task() + execution_node = traversal_node.to_mock_execution_node() connector.run_consent_request( - node=TraversalNode(generate_node("a", "b", "c", "c2")), + node=execution_node, policy=consent_policy, privacy_request=privacy_request_with_consent_policy, identity_data={"ljt_readerID": "abcde"}, + request_task=request_task, session=db, ) assert "no actionable consent preferences to propagate" in str(exc) @@ -525,11 +553,15 @@ def test_missing_identity_data_failure( connector = get_connector(mailchimp_transactional_connection_config_no_secrets) with pytest.raises(ValueError): + traversal_node = TraversalNode(generate_node("a", "b", "c", "c2")) + request_task = traversal_node.to_mock_request_task() + execution_node = traversal_node.to_mock_execution_node() connector.run_consent_request( - node=TraversalNode(generate_node("a", "b", "c", "c2")), + node=execution_node, policy=consent_policy, privacy_request=privacy_request, identity_data={"ljt_readerID": "abcde"}, + request_task=request_task, session=db, ) @@ -562,11 +594,15 @@ def test_missing_identity_data_skipped( connector = get_connector(google_analytics_connection_config_without_secrets) with pytest.raises(SkippingConsentPropagation) as exc: + traversal_node = TraversalNode(generate_node("a", "b", "c", "c2")) + request_task = traversal_node.to_mock_request_task() + execution_node = traversal_node.to_mock_execution_node() connector.run_consent_request( - node=TraversalNode(generate_node("a", "b", "c", "c2")), + node=execution_node, policy=consent_policy, privacy_request=privacy_request, identity_data={"ljt_readerID": "abcde"}, + request_task=request_task, session=db, ) @@ -599,11 +635,15 @@ def test_no_requests_of_that_type_defined( connector = get_connector(google_analytics_connection_config_without_secrets) with pytest.raises(SkippingConsentPropagation) as exc: + traversal_node = TraversalNode(generate_node("a", "b", "c", "c2")) + request_task = traversal_node.to_mock_request_task() + execution_node = traversal_node.to_mock_execution_node() connector.run_consent_request( - node=TraversalNode(generate_node("a", "b", "c", "c2")), + node=execution_node, policy=consent_policy, privacy_request=privacy_request, identity_data={"ljt_readerID": "abcde"}, + request_task=request_task, session=db, ) @@ -629,11 +669,15 @@ def test_preferences_executable( privacy_preference_history.save(db) connector = get_connector(mailchimp_transactional_connection_config_no_secrets) + traversal_node = TraversalNode(generate_node("a", "b", "c", "c2")) + request_task = traversal_node.to_mock_request_task() + execution_node = traversal_node.to_mock_execution_node() connector.run_consent_request( - node=TraversalNode(generate_node("a", "b", "c", "c2")), + node=execution_node, policy=consent_policy, privacy_request=privacy_request_with_consent_policy, identity_data={"ljt_readerID": "abcde"}, + request_task=request_task, session=db, ) assert mock_send.called diff --git a/tests/ops/service/connectors/test_saas_queryconfig.py b/tests/ops/service/connectors/test_saas_queryconfig.py index 546734594d..c82f264ee0 100644 --- a/tests/ops/service/connectors/test_saas_queryconfig.py +++ b/tests/ops/service/connectors/test_saas_queryconfig.py @@ -82,7 +82,7 @@ def test_generate_requests( ] payment_methods = 
combined_traversal.traversal_node_dict[ CollectionAddress(saas_config.fides_key, "payment_methods") - ] + ].to_mock_execution_node() # static path with single query param config = SaaSQueryConfig( @@ -179,7 +179,7 @@ def test_generate_update_stmt( member = combined_traversal.traversal_node_dict[ CollectionAddress(saas_config.fides_key, "member") - ] + ].to_mock_execution_node() config = SaaSQueryConfig(member, endpoints, {}, update_request) row = { @@ -216,7 +216,7 @@ def test_generate_update_stmt_custom_http_method( member = combined_traversal.traversal_node_dict[ CollectionAddress(saas_config.fides_key, "member") - ] + ].to_mock_execution_node() update_request = endpoints["member"].requests.update config = SaaSQueryConfig(member, endpoints, {}, update_request) @@ -270,10 +270,10 @@ def test_generate_update_stmt_with_request_body( update_request = endpoints["member"].requests.update member = combined_traversal.traversal_node_dict[ CollectionAddress(saas_config.fides_key, "member") - ] + ].to_mock_execution_node() payment_methods = combined_traversal.traversal_node_dict[ CollectionAddress(saas_config.fides_key, "payment_methods") - ] + ].to_mock_execution_node() config = SaaSQueryConfig(member, endpoints, {}, update_request) row = { @@ -329,7 +329,7 @@ def test_generate_update_stmt_with_url_encoded_body( endpoints = saas_config.top_level_endpoint_dict customer = combined_traversal.traversal_node_dict[ CollectionAddress(saas_config.fides_key, "customer") - ] + ].to_mock_execution_node() # update with multidimensional urlcoding # omit read-only fields and fields not defined in the dataset @@ -652,7 +652,7 @@ def test_custom_privacy_request_fields( internal_information = combined_traversal.traversal_node_dict[ CollectionAddress(saas_config.fides_key, "internal_information") - ] + ].to_mock_execution_node() config = SaaSQueryConfig( internal_information, diff --git a/tests/ops/service/privacy_request/test_request_runner_service.py b/tests/ops/service/privacy_request/test_request_runner_service.py index 51f9f9bca5..8f07b7746d 100644 --- a/tests/ops/service/privacy_request/test_request_runner_service.py +++ b/tests/ops/service/privacy_request/test_request_runner_service.py @@ -81,13 +81,21 @@ def privacy_request_complete_email_notification_enabled(db): @mock.patch("fides.api.service.privacy_request.request_runner_service.dispatch_message") @mock.patch("fides.api.service.privacy_request.request_runner_service.upload") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_policy_upload_dispatch_message_called( upload_mock: Mock, mock_email_dispatch: Mock, privacy_request_status_pending: PrivacyRequest, run_privacy_request_task, + dsr_version, + request, privacy_request_complete_email_notification_enabled, ) -> None: + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + upload_mock.return_value = "http://www.data-download-url" run_privacy_request_task.delay(privacy_request_status_pending.id).get( timeout=PRIVACY_REQUEST_TASK_TIMEOUT @@ -98,13 +106,21 @@ def test_policy_upload_dispatch_message_called( @mock.patch("fides.api.service.privacy_request.request_runner_service.dispatch_message") @mock.patch("fides.api.service.privacy_request.request_runner_service.upload") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_complete_email_not_sent_if_consent_request( upload_mock: Mock, mock_email_dispatch: Mock, privacy_request_with_consent_policy: PrivacyRequest, run_privacy_request_task, + dsr_version, + 
request, privacy_request_complete_email_notification_enabled, ) -> None: + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + upload_mock.return_value = "http://www.data-download-url" run_privacy_request_task.delay(privacy_request_with_consent_policy.id).get( timeout=PRIVACY_REQUEST_TASK_TIMEOUT @@ -114,14 +130,22 @@ def test_complete_email_not_sent_if_consent_request( @mock.patch("fides.api.service.privacy_request.request_runner_service.dispatch_message") @mock.patch("fides.api.service.privacy_request.request_runner_service.upload") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_start_processing_sets_started_processing_at( upload_mock: Mock, mock_email_dispatch: Mock, db: Session, privacy_request_status_pending: PrivacyRequest, run_privacy_request_task, + request, + dsr_version, privacy_request_complete_email_notification_enabled, ) -> None: + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + upload_mock.return_value = "http://www.data-download-url" updated_at = privacy_request_status_pending.updated_at assert privacy_request_status_pending.started_processing_at is None @@ -138,14 +162,22 @@ def test_start_processing_sets_started_processing_at( @mock.patch("fides.api.service.privacy_request.request_runner_service.dispatch_message") @mock.patch("fides.api.service.privacy_request.request_runner_service.upload") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_start_processing_doesnt_overwrite_started_processing_at( upload_mock: Mock, mock_email_dispatch: Mock, db: Session, privacy_request: PrivacyRequest, run_privacy_request_task, + request, + dsr_version, privacy_request_complete_email_notification_enabled, ) -> None: + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + upload_mock.return_value = "http://www.data-download-url" before = privacy_request.started_processing_at assert before is not None @@ -165,13 +197,21 @@ def test_start_processing_doesnt_overwrite_started_processing_at( @mock.patch( "fides.api.service.privacy_request.request_runner_service.upload_access_results" ) +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_halts_proceeding_if_cancelled( upload_access_results_mock, db: Session, privacy_request_status_canceled: PrivacyRequest, run_privacy_request_task, + dsr_version, + request, privacy_request_complete_email_notification_enabled, ) -> None: + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + assert privacy_request_status_canceled.status == PrivacyRequestStatus.canceled run_privacy_request_task.delay(privacy_request_status_canceled.id).get( timeout=PRIVACY_REQUEST_TASK_TIMEOUT @@ -192,10 +232,12 @@ def test_halts_proceeding_if_cancelled( @mock.patch( "fides.api.service.privacy_request.request_runner_service.run_webhooks_and_report_status", ) -@mock.patch( - "fides.api.service.privacy_request.request_runner_service.run_access_request" +@mock.patch("fides.api.service.privacy_request.request_runner_service.access_runner") +@mock.patch("fides.api.service.privacy_request.request_runner_service.erasure_runner") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], ) -@mock.patch("fides.api.service.privacy_request.request_runner_service.run_erasure") def test_from_graph_resume_does_not_run_pre_webhooks( run_erasure, run_access, @@ -206,8 +248,12 @@ def test_from_graph_resume_does_not_run_pre_webhooks( privacy_request: 
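The runner entry points were renamed for the new scheduler, so tests now patch `access_runner` and `erasure_runner` in `request_runner_service` instead of `run_access_request` and `run_erasure`. A minimal sketch of the updated patch targets, including the bottom-up argument order that stacked `mock.patch` decorators impose:

```python
from unittest import mock


@mock.patch("fides.api.service.privacy_request.request_runner_service.access_runner")
@mock.patch("fides.api.service.privacy_request.request_runner_service.erasure_runner")
def test_runners_not_called(erasure_runner_mock, access_runner_mock):
    # mock.patch decorators apply bottom-up, so the erasure patch arrives first.
    assert not access_runner_mock.called
    assert not erasure_runner_mock.called
```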
PrivacyRequest, run_privacy_request_task, erasure_policy, + dsr_version, + request, privacy_request_complete_email_notification_enabled, ) -> None: + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + upload_mock.return_value = "http://www.data-download-url" privacy_request.started_processing_at = None privacy_request.policy = erasure_policy @@ -237,10 +283,8 @@ def test_from_graph_resume_does_not_run_pre_webhooks( @mock.patch( "fides.api.service.privacy_request.request_runner_service.run_webhooks_and_report_status", ) -@mock.patch( - "fides.api.service.privacy_request.request_runner_service.run_access_request" -) -@mock.patch("fides.api.service.privacy_request.request_runner_service.run_erasure") +@mock.patch("fides.api.service.privacy_request.request_runner_service.access_runner") +@mock.patch("fides.api.service.privacy_request.request_runner_service.erasure_runner") def test_resume_privacy_request_from_erasure( run_erasure, run_access, @@ -334,18 +378,28 @@ def get_privacy_request_results( @pytest.mark.integration_postgres @pytest.mark.integration @mock.patch("fides.api.service.privacy_request.request_runner_service.upload") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_upload_access_results_has_data_category_field_mapping( upload_mock: Mock, postgres_example_test_dataset_config_read_access, postgres_integration_db, db, policy, + dsr_version, + request, run_privacy_request_task, ): """ Ensure we are passing along a correctly populated data_category_field_mapping to the 'upload' function that publishes the access request output. """ + upload_mock.return_value = "http://www.data-download-url" + + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -361,7 +415,7 @@ def test_upload_access_results_has_data_category_field_mapping( ) # sanity check that acccess results returned as expected - results = pr.get_results() + results = pr.get_raw_access_results() assert len(results.keys()) == 11 # what we're really testing - ensure data_category_field_mapping arg is well-populated @@ -397,18 +451,28 @@ def test_upload_access_results_has_data_category_field_mapping( @pytest.mark.integration_postgres @pytest.mark.integration @mock.patch("fides.api.service.privacy_request.request_runner_service.upload") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_upload_access_results_has_data_use_map( upload_mock: Mock, postgres_example_test_dataset_config_read_access, postgres_integration_db, db, policy, + dsr_version, + request, run_privacy_request_task, ): """ Ensure we are passing along a correctly populated data_use_map to the 'upload' function that publishes the access request output. 
""" + upload_mock.return_value = "http://www.data-download-url" + + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -423,8 +487,8 @@ def test_upload_access_results_has_data_use_map( data, ) - # sanity check that acccess results returned as expected - results = pr.get_results() + # sanity check that access results returned as expected + results = pr.get_raw_access_results() assert len(results.keys()) == 11 # what we're really testing - ensure data_use_map arg is well-populated @@ -449,17 +513,25 @@ def test_upload_access_results_has_data_use_map( @pytest.mark.integration_postgres @pytest.mark.integration @mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_access_request_postgres( trigger_webhook_mock, postgres_example_test_dataset_config_read_access, postgres_integration_db, db, cache, + dsr_version, + request, policy, policy_pre_execution_webhooks, policy_post_execution_webhooks, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -474,14 +546,14 @@ def test_create_and_process_access_request_postgres( data, ) - results = pr.get_results() + results = pr.get_raw_access_results() assert len(results.keys()) == 11 for key in results.keys(): assert results[key] is not None assert results[key] != {} - result_key_prefix = f"EN_{pr.id}__access_request__postgres_example_test_dataset:" + result_key_prefix = f"postgres_example_test_dataset:" customer_key = result_key_prefix + "customer" assert results[customer_key][0]["email"] == customer_email @@ -517,6 +589,10 @@ def test_create_and_process_access_request_postgres( @pytest.mark.integration_postgres @pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) @mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") def test_create_and_process_access_request_with_custom_identities_postgres( trigger_webhook_mock, @@ -526,10 +602,14 @@ def test_create_and_process_access_request_with_custom_identities_postgres( db, cache, policy, + dsr_version, + request, policy_pre_execution_webhooks, policy_post_execution_webhooks, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" loyalty_id = "CH-1" data = { @@ -548,23 +628,21 @@ def test_create_and_process_access_request_with_custom_identities_postgres( data, ) - results = pr.get_results() + results = pr.get_raw_access_results() assert len(results.keys()) == 12 for key in results.keys(): assert results[key] is not None assert results[key] != {} - result_key_prefix = f"EN_{pr.id}__access_request__postgres_example_test_dataset:" + result_key_prefix = f"postgres_example_test_dataset:" customer_key = result_key_prefix + "customer" assert results[customer_key][0]["email"] == customer_email visit_key = result_key_prefix + "visit" assert results[visit_key][0]["email"] == customer_email - loyalty_key = ( - f"EN_{pr.id}__access_request__postgres_example_test_extended_dataset:loyalty" - ) + loyalty_key = f"postgres_example_test_extended_dataset:loyalty" assert results[loyalty_key][0]["id"] == loyalty_id log_id = 
pr.execution_logs[0].id @@ -602,11 +680,19 @@ def test_create_and_process_access_request_with_custom_identities_postgres( "postgres_integration_db", "cache", ) +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_access_request_with_valid_skipped_collection( db, policy, run_privacy_request_task, + dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -621,12 +707,12 @@ def test_create_and_process_access_request_with_valid_skipped_collection( data, ) - results = pr.get_results() + results = pr.get_raw_access_results() assert len(results.keys()) == 10 assert "login" not in results.keys() - result_key_prefix = f"EN_{pr.id}__access_request__postgres_example_test_dataset:" + result_key_prefix = f"postgres_example_test_dataset:" customer_key = result_key_prefix + "customer" assert results[customer_key][0]["email"] == customer_email @@ -646,11 +732,19 @@ def test_create_and_process_access_request_with_valid_skipped_collection( "postgres_integration_db", "cache", ) +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_access_request_with_invalid_skipped_collection( db, policy, + dsr_version, + request, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -665,7 +759,7 @@ def test_create_and_process_access_request_with_invalid_skipped_collection( data, ) - results = pr.get_results() + results = pr.get_raw_access_results() assert len(results.keys()) == 0 db.refresh(pr) @@ -675,6 +769,10 @@ def test_create_and_process_access_request_with_invalid_skipped_collection( @pytest.mark.integration @mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_access_request_mssql( trigger_webhook_mock, mssql_example_test_dataset_config, @@ -682,10 +780,14 @@ def test_create_and_process_access_request_mssql( db, cache, policy, + dsr_version, + request, policy_pre_execution_webhooks, policy_post_execution_webhooks, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -700,14 +802,14 @@ def test_create_and_process_access_request_mssql( data, ) - results = pr.get_results() + results = pr.get_raw_access_results() assert len(results.keys()) == 11 for key in results.keys(): assert results[key] is not None assert results[key] != {} - result_key_prefix = f"EN_{pr.id}__access_request__mssql_example_test_dataset:" + result_key_prefix = f"mssql_example_test_dataset:" customer_key = result_key_prefix + "customer" assert results[customer_key][0]["email"] == customer_email @@ -720,6 +822,10 @@ def test_create_and_process_access_request_mssql( @pytest.mark.integration @mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_access_request_mysql( trigger_webhook_mock, mysql_example_test_dataset_config, @@ -727,10 +833,14 @@ def test_create_and_process_access_request_mysql( db, cache, policy, + 
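The recurring assertion rewrite in these hunks is the result-key change: `pr.get_results()` exposed raw cache keys such as `EN_{pr.id}__access_request__dataset:collection`, whereas `pr.get_raw_access_results()` is keyed by collection address alone, so the expected keys no longer embed the privacy request id. A small sketch of the new lookup:

```python
def customer_rows(pr):
    """Fetch access rows for the customer collection under the new key format."""
    # Old DSR 2.0 assertion key, embedding the privacy request id:
    #   f"EN_{pr.id}__access_request__postgres_example_test_dataset:customer"
    # New scheduler-agnostic key returned by get_raw_access_results():
    #   "postgres_example_test_dataset:customer"
    results = pr.get_raw_access_results()
    return results["postgres_example_test_dataset:customer"]
```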
dsr_version, + request, policy_pre_execution_webhooks, policy_post_execution_webhooks, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -745,14 +855,14 @@ def test_create_and_process_access_request_mysql( data, ) - results = pr.get_results() + results = pr.get_raw_access_results() assert len(results.keys()) == 11 for key in results.keys(): assert results[key] is not None assert results[key] != {} - result_key_prefix = f"EN_{pr.id}__access_request__mysql_example_test_dataset:" + result_key_prefix = f"mysql_example_test_dataset:" customer_key = result_key_prefix + "customer" assert results[customer_key][0]["email"] == customer_email @@ -766,6 +876,10 @@ def test_create_and_process_access_request_mysql( @pytest.mark.integration_mariadb @pytest.mark.integration @mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_access_request_mariadb( trigger_webhook_mock, mariadb_example_test_dataset_config, @@ -773,10 +887,14 @@ def test_create_and_process_access_request_mariadb( db, cache, policy, + dsr_version, + request, policy_pre_execution_webhooks, policy_post_execution_webhooks, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -791,14 +909,14 @@ def test_create_and_process_access_request_mariadb( data, ) - results = pr.get_results() + results = pr.get_raw_access_results() assert len(results.keys()) == 11 for key in results.keys(): assert results[key] is not None assert results[key] != {} - result_key_prefix = f"EN_{pr.id}__access_request__mariadb_example_test_dataset:" + result_key_prefix = "mariadb_example_test_dataset:" customer_key = result_key_prefix + "customer" assert results[customer_key][0]["email"] == customer_email @@ -811,6 +929,10 @@ def test_create_and_process_access_request_mariadb( @pytest.mark.integration_saas @mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_access_request_saas_mailchimp( trigger_webhook_mock, mailchimp_connection_config, @@ -820,9 +942,13 @@ def test_create_and_process_access_request_saas_mailchimp( policy, policy_pre_execution_webhooks, policy_post_execution_webhooks, + dsr_version, + request, mailchimp_identity_email, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = mailchimp_identity_email data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -837,14 +963,14 @@ def test_create_and_process_access_request_saas_mailchimp( data, task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, ) - results = pr.get_results() + results = pr.get_raw_access_results() assert len(results.keys()) == 3 for key in results.keys(): assert results[key] is not None assert results[key] != {} - result_key_prefix = f"EN_{pr.id}__access_request__mailchimp_instance:" + result_key_prefix = f"mailchimp_instance:" member_key = result_key_prefix + "member" assert results[member_key][0]["email_address"] == customer_email @@ -856,6 +982,10 @@ def test_create_and_process_access_request_saas_mailchimp( @pytest.mark.integration_saas 
@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_erasure_request_saas( _, mailchimp_connection_config, @@ -864,10 +994,14 @@ def test_create_and_process_erasure_request_saas( cache, erasure_policy_hmac, generate_auth_header, + dsr_version, + request, mailchimp_identity_email, reset_mailchimp_data, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = mailchimp_identity_email data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -917,6 +1051,10 @@ def test_create_and_process_erasure_request_saas( @pytest.mark.integration_saas @mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_access_request_saas_hubspot( trigger_webhook_mock, connection_config_hubspot, @@ -926,9 +1064,13 @@ def test_create_and_process_access_request_saas_hubspot( policy, policy_pre_execution_webhooks, policy_post_execution_webhooks, + dsr_version, + request, hubspot_identity_email, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = hubspot_identity_email data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -943,14 +1085,14 @@ def test_create_and_process_access_request_saas_hubspot( data, task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, ) - results = pr.get_results() + results = pr.get_raw_access_results() assert len(results.keys()) == 4 for key in results.keys(): assert results[key] is not None assert results[key] != {} - result_key_prefix = f"EN_{pr.id}__access_request__hubspot_instance:" + result_key_prefix = f"hubspot_instance:" contacts_key = result_key_prefix + "contacts" assert results[contacts_key][0]["properties"]["email"] == customer_email @@ -962,6 +1104,10 @@ def test_create_and_process_access_request_saas_hubspot( @pytest.mark.integration_postgres @pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_erasure_request_specific_category_postgres( postgres_integration_db, postgres_example_test_dataset_config, @@ -969,9 +1115,13 @@ def test_create_and_process_erasure_request_specific_category_postgres( db, generate_auth_header, erasure_policy, + dsr_version, + request, read_connection_config, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" customer_id = 1 data = { @@ -1008,15 +1158,23 @@ def test_create_and_process_erasure_request_specific_category_postgres( @pytest.mark.integration_mssql @pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_erasure_request_specific_category_mssql( mssql_integration_db, mssql_example_test_dataset_config, cache, db, + dsr_version, + request, generate_auth_header, erasure_policy, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" customer_id = 1 data = { @@ -1050,15 +1208,23 @@ def test_create_and_process_erasure_request_specific_category_mssql( @pytest.mark.integration_mysql @pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def 
test_create_and_process_erasure_request_specific_category_mysql( mysql_integration_db, mysql_example_test_dataset_config, cache, db, + dsr_version, + request, generate_auth_header, erasure_policy, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" customer_id = 1 data = { @@ -1092,15 +1258,23 @@ def test_create_and_process_erasure_request_specific_category_mysql( @pytest.mark.integration_mariadb @pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_erasure_request_specific_category_mariadb( mariadb_example_test_dataset_config, mariadb_integration_db, cache, db, + dsr_version, + request, generate_auth_header, erasure_policy, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" customer_id = 1 data = { @@ -1134,15 +1308,23 @@ def test_create_and_process_erasure_request_specific_category_mariadb( @pytest.mark.integration_postgres @pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_erasure_request_generic_category( postgres_integration_db, postgres_example_test_dataset_config, cache, db, + dsr_version, + request, generate_auth_header, erasure_policy, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + # It's safe to change this here since the `erasure_policy` fixture is scoped # at function level target = erasure_policy.rules[0].targets[0] @@ -1189,15 +1371,23 @@ def test_create_and_process_erasure_request_generic_category( @pytest.mark.integration_postgres @pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_erasure_request_aes_generic_category( postgres_integration_db, postgres_example_test_dataset_config, cache, db, + dsr_version, + request, generate_auth_header, erasure_policy_aes, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + # It's safe to change this here since the `erasure_policy` fixture is scoped # at function level target = erasure_policy_aes.rules[0].targets[0] @@ -1246,14 +1436,22 @@ def test_create_and_process_erasure_request_aes_generic_category( @pytest.mark.integration_postgres @pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_erasure_request_with_table_joins( postgres_integration_db, postgres_example_test_dataset_config, db, cache, + dsr_version, + request, erasure_policy, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + # It's safe to change this here since the `erasure_policy` fixture is scoped # at function level target = erasure_policy.rules[0].targets[0] @@ -1298,14 +1496,22 @@ def test_create_and_process_erasure_request_with_table_joins( @pytest.mark.integration_postgres @pytest.mark.integration +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_erasure_request_read_access( postgres_integration_db, postgres_example_test_dataset_config_read_access, db, cache, erasure_policy, + dsr_version, + request, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + 
customer_email = "customer-2@example.com" customer_id = 2 data = { @@ -1384,13 +1590,21 @@ def snowflake_resources( @pytest.mark.integration_external @pytest.mark.integration_snowflake +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_access_request_snowflake( snowflake_resources, db, cache, policy, + dsr_version, + request, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = snowflake_resources["email"] customer_name = snowflake_resources["name"] data = { @@ -1405,10 +1619,8 @@ def test_create_and_process_access_request_snowflake( data, task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, ) - results = pr.get_results() - customer_table_key = ( - f"EN_{pr.id}__access_request__snowflake_example_test_dataset:customer" - ) + results = pr.get_raw_access_results() + customer_table_key = f"snowflake_example_test_dataset:customer" assert len(results[customer_table_key]) == 1 assert results[customer_table_key][0]["email"] == customer_email assert results[customer_table_key][0]["name"] == customer_name @@ -1418,15 +1630,23 @@ def test_create_and_process_access_request_snowflake( @pytest.mark.integration_external @pytest.mark.integration_snowflake +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_erasure_request_snowflake( snowflake_example_test_dataset_config, snowflake_resources, integration_config: Dict[str, str], db, cache, + dsr_version, + request, erasure_policy, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = snowflake_resources["email"] snowflake_client = snowflake_resources["client"] formatted_customer_email = snowflake_resources["formatted_email"] @@ -1506,9 +1726,21 @@ def redshift_resources( @pytest.mark.integration_external @pytest.mark.integration_redshift +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_access_request_redshift( - redshift_resources, db, cache, policy, run_privacy_request_task + redshift_resources, + db, + cache, + policy, + run_privacy_request_task, + dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = redshift_resources["email"] customer_name = redshift_resources["name"] data = { @@ -1523,17 +1755,13 @@ def test_create_and_process_access_request_redshift( data, task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, ) - results = pr.get_results() - customer_table_key = ( - f"EN_{pr.id}__access_request__redshift_example_test_dataset:customer" - ) + results = pr.get_raw_access_results() + customer_table_key = f"redshift_example_test_dataset:customer" assert len(results[customer_table_key]) == 1 assert results[customer_table_key][0]["email"] == customer_email assert results[customer_table_key][0]["name"] == customer_name - address_table_key = ( - f"EN_{pr.id}__access_request__redshift_example_test_dataset:address" - ) + address_table_key = f"redshift_example_test_dataset:address" city = redshift_resources["city"] state = redshift_resources["state"] @@ -1546,6 +1774,10 @@ def test_create_and_process_access_request_redshift( @pytest.mark.integration_external @pytest.mark.integration_redshift +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_erasure_request_redshift( redshift_example_test_dataset_config, 
redshift_resources, @@ -1553,8 +1785,12 @@ def test_create_and_process_erasure_request_redshift( db, cache, erasure_policy, + dsr_version, + request, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = redshift_resources["email"] data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -1619,13 +1855,21 @@ def test_create_and_process_erasure_request_redshift( @pytest.mark.integration_external @pytest.mark.integration_bigquery +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_access_request_bigquery( bigquery_resources, db, cache, policy, + dsr_version, + request, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = bigquery_resources["email"] customer_name = bigquery_resources["name"] data = { @@ -1640,17 +1884,13 @@ def test_create_and_process_access_request_bigquery( data, task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, ) - results = pr.get_results() - customer_table_key = ( - f"EN_{pr.id}__access_request__bigquery_example_test_dataset:customer" - ) + results = pr.get_raw_access_results() + customer_table_key = "bigquery_example_test_dataset:customer" assert len(results[customer_table_key]) == 1 assert results[customer_table_key][0]["email"] == customer_email assert results[customer_table_key][0]["name"] == customer_name - address_table_key = ( - f"EN_{pr.id}__access_request__bigquery_example_test_dataset:address" - ) + address_table_key = "bigquery_example_test_dataset:address" city = bigquery_resources["city"] state = bigquery_resources["state"] @@ -1663,15 +1903,23 @@ @pytest.mark.integration_external @pytest.mark.integration_bigquery +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_erasure_request_bigquery( bigquery_example_test_dataset_config, bigquery_resources, integration_config: Dict[str, str], db, cache, + dsr_version, + request, erasure_policy, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = bigquery_resources["email"] data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -1862,53 +2110,6 @@ def test_run_webhooks_after_webhook( ) -@pytest.mark.integration_postgres -@pytest.mark.integration -@mock.patch( - "fides.api.service.privacy_request.request_runner_service.run_access_request" ) -@mock.patch("fides.api.models.privacy_request.PrivacyRequest.trigger_policy_webhook") -def test_privacy_request_log_failure( - _, - run_access_request_mock, - postgres_example_test_dataset_config_read_access, - postgres_integration_db, - db, - cache, - policy, - policy_pre_execution_webhooks, - policy_post_execution_webhooks, - run_privacy_request_task, -): - run_access_request_mock.side_effect = KeyError("Test error") - customer_email = "customer-1@example.com" - data = { - "requested_at": "2021-08-30T16:09:37.359Z", - "policy_key": policy.key, - "identity": {"email": customer_email}, - } - - with mock.patch( - "fides.api.service.privacy_request.request_runner_service.fideslog_graph_failure" - ) as mock_log_event: - pr = get_privacy_request_results( - db, - policy, - run_privacy_request_task, - data, - ) - sent_event = mock_log_event.call_args.args[0] - assert sent_event.docker is True - assert sent_event.event == "privacy_request_execution_failure" - assert sent_event.event_created_at is not 
None - - assert sent_event.local_host is None - assert sent_event.endpoint is None - assert sent_event.status_code == 500 - assert sent_event.error == "KeyError" - assert sent_event.extra_data == {"privacy_request": pr.id} - - class TestPrivacyRequestsEmailNotifications: @pytest.fixture(scope="function") def privacy_request_complete_email_notification_enabled(self, db): @@ -1922,6 +2123,10 @@ def privacy_request_complete_email_notification_enabled(self, db): @pytest.mark.integration_postgres @pytest.mark.integration + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) @mock.patch( "fides.api.service.privacy_request.request_runner_service.dispatch_message" ) @@ -1936,9 +2141,13 @@ def test_email_complete_send_erasure( erasure_policy, read_connection_config, messaging_config, + dsr_version, + request, privacy_request_complete_email_notification_enabled, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -1962,6 +2171,10 @@ def test_email_complete_send_erasure( "fides.api.service.privacy_request.request_runner_service.dispatch_message" ) @mock.patch("fides.api.service.privacy_request.request_runner_service.upload") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_email_complete_send_access( self, upload_mock, @@ -1976,7 +2189,11 @@ def test_email_complete_send_access( messaging_config, privacy_request_complete_email_notification_enabled, run_privacy_request_task, + dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + upload_mock.return_value = "http://www.data-download-url" customer_email = "customer-1@example.com" data = { @@ -1997,6 +2214,10 @@ def test_email_complete_send_access( @pytest.mark.integration_postgres @pytest.mark.integration + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) @mock.patch( "fides.api.service.privacy_request.request_runner_service.dispatch_message" ) @@ -2013,9 +2234,13 @@ def test_email_complete_send_access_and_erasure( access_and_erasure_policy, read_connection_config, messaging_config, + dsr_version, + request, privacy_request_complete_email_notification_enabled, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + upload_mock.return_value = "http://www.data-download-url" download_time_in_days = "5" customer_email = "customer-1@example.com" @@ -2063,6 +2288,10 @@ def test_email_complete_send_access_and_erasure( "fides.api.service.messaging.message_dispatch_service._mailgun_dispatcher" ) @mock.patch("fides.api.service.privacy_request.request_runner_service.upload") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_email_complete_send_access_no_messaging_config( self, upload_mock, @@ -2074,9 +2303,13 @@ def test_email_complete_send_access_no_messaging_config( generate_auth_header, policy, read_connection_config, + dsr_version, + request, privacy_request_complete_email_notification_enabled, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + upload_mock.return_value = "http://www.data-download-url" customer_email = "customer-1@example.com" data = { @@ -2103,6 +2336,10 @@ def test_email_complete_send_access_no_messaging_config( 
"fides.api.service.messaging.message_dispatch_service._mailgun_dispatcher" ) @mock.patch("fides.api.service.privacy_request.request_runner_service.upload") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_email_complete_send_access_no_email_identity( self, upload_mock, @@ -2116,7 +2353,11 @@ def test_email_complete_send_access_no_email_identity( read_connection_config, privacy_request_complete_email_notification_enabled, run_privacy_request_task, + dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + upload_mock.return_value = "http://www.data-download-url" data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -2139,6 +2380,10 @@ def test_email_complete_send_access_no_email_identity( class TestPrivacyRequestsManualWebhooks: @mock.patch("fides.api.service.privacy_request.request_runner_service.upload") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_privacy_request_needs_manual_input_key_in_cache( self, mock_upload, @@ -2147,7 +2392,11 @@ def test_privacy_request_needs_manual_input_key_in_cache( policy, run_privacy_request_task, db, + dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -2166,7 +2415,13 @@ def test_privacy_request_needs_manual_input_key_in_cache( assert not mock_upload.called @mock.patch("fides.api.service.privacy_request.request_runner_service.upload") - @mock.patch("fides.api.service.privacy_request.request_runner_service.run_erasure") + @mock.patch( + "fides.api.service.privacy_request.request_runner_service.erasure_runner" + ) + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_manual_input_required_for_erasure_only_policies( self, mock_erasure, @@ -2174,10 +2429,14 @@ def test_manual_input_required_for_erasure_only_policies( integration_manual_webhook_config, access_manual_webhook, erasure_policy, + dsr_version, + request, run_privacy_request_task, db, ): """Manual inputs are not tied to policies, but should still hold up a request even for erasure requests.""" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = "customer-1@example.com" data = { "requested_at": "2021-08-30T16:09:37.359Z", @@ -2197,6 +2456,10 @@ def test_manual_input_required_for_erasure_only_policies( assert not mock_erasure.called @mock.patch("fides.api.service.privacy_request.request_runner_service.upload") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_pass_on_manually_added_input( self, mock_upload, @@ -2206,8 +2469,13 @@ def test_pass_on_manually_added_input( run_privacy_request_task, privacy_request_requires_input: PrivacyRequest, db, + dsr_version, + request, cached_access_input, ): + mock_upload.return_value = "http://www.data-download-url" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + run_privacy_request_task.delay(privacy_request_requires_input.id).get( timeout=PRIVACY_REQUEST_TASK_TIMEOUT ) @@ -2221,6 +2489,10 @@ def test_pass_on_manually_added_input( } @mock.patch("fides.api.service.privacy_request.request_runner_service.upload") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_pass_on_partial_manually_added_input( self, mock_upload, @@ -2228,9 +2500,14 @@ def 
test_pass_on_partial_manually_added_input( access_manual_webhook, policy, run_privacy_request_task, + dsr_version, + request, privacy_request_requires_input: PrivacyRequest, db, ): + mock_upload.return_value = "http://www.data-download-url" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + privacy_request_requires_input.cache_manual_webhook_access_input( access_manual_webhook, {"email": "customer-1@example.com"}, @@ -2249,6 +2526,10 @@ def test_pass_on_partial_manually_added_input( ] } + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) @mock.patch("fides.api.service.privacy_request.request_runner_service.upload") def test_pass_on_empty_confirmed_input( self, @@ -2259,7 +2540,12 @@ def test_pass_on_empty_confirmed_input( run_privacy_request_task, privacy_request_requires_input: PrivacyRequest, db, + dsr_version, + request, ): + mock_upload.return_value = "http://www.data-download-url" + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + privacy_request_requires_input.cache_manual_webhook_access_input( access_manual_webhook, {}, @@ -2299,17 +2585,25 @@ def test_build_consent_dataset_graph( class TestConsentEmailStep: + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_privacy_request_completes_if_no_consent_email_send_needed( self, db, privacy_request_with_consent_policy, run_privacy_request_task, + dsr_version, + request, sovrn_email_connection_config, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + run_privacy_request_task.delay( privacy_request_id=privacy_request_with_consent_policy.id, from_step=None, - ).get(timeout=PRIVACY_REQUEST_TASK_TIMEOUT) + ).get(timeout=5) db.refresh(privacy_request_with_consent_policy) assert ( privacy_request_with_consent_policy.status == PrivacyRequestStatus.complete @@ -2326,12 +2620,20 @@ def test_privacy_request_completes_if_no_consent_email_send_needed( ] @pytest.mark.usefixtures("sovrn_email_connection_config") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_privacy_request_is_put_in_awaiting_email_send_status_old_workflow( self, db, privacy_request_with_consent_policy, run_privacy_request_task, + dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + identity = Identity(email="customer_1#@example.com", ljt_readerID="12345") privacy_request_with_consent_policy.cache_identity(identity) privacy_request_with_consent_policy.consent_preferences = [ @@ -2351,13 +2653,21 @@ def test_privacy_request_is_put_in_awaiting_email_send_status_old_workflow( assert privacy_request_with_consent_policy.awaiting_email_send_at is not None @pytest.mark.usefixtures("sovrn_email_connection_config") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_privacy_request_is_put_in_awaiting_email_new_workflow( self, db, privacy_request_with_consent_policy, run_privacy_request_task, + dsr_version, + request, privacy_preference_history, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + identity = Identity(email="customer_1#@example.com", ljt_readerID="12345") privacy_request_with_consent_policy.cache_identity(identity) privacy_preference_history.privacy_request_id = ( @@ -2467,15 +2777,23 @@ def test_needs_batch_email_send_system_and_notice_data_use_mismatch( ) @pytest.mark.usefixtures("sovrn_email_connection_config") + @pytest.mark.parametrize( + 
"dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_skipped_batch_email_send_updates_privacy_preferences_with_system_status( self, db, privacy_request_with_consent_policy, system, + dsr_version, + request, privacy_preference_history_us_ca_provide, sovrn_email_connection_config, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + sovrn_email_connection_config.system_id = system.id sovrn_email_connection_config.save(db) @@ -2503,9 +2821,20 @@ def test_skipped_batch_email_send_updates_privacy_preferences_with_system_status ].affected_system_status == {system.fides_key: "skipped"} @pytest.mark.usefixtures("sovrn_email_connection_config") + @pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], + ) def test_needs_batch_email_send_new_workflow( - self, db, privacy_request_with_consent_policy, privacy_preference_history + self, + db, + privacy_request_with_consent_policy, + privacy_preference_history, + dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + privacy_preference_history.privacy_request_id = ( privacy_request_with_consent_policy.id ) @@ -2615,12 +2944,20 @@ def dynamodb_resources( @pytest.mark.integration_external @pytest.mark.integration_dynamodb +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_empty_access_request_dynamodb( db, cache, policy, + dsr_version, + request, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + data = { "requested_at": "2021-08-30T16:09:37.359Z", "policy_key": policy.key, @@ -2635,20 +2972,28 @@ def test_create_and_process_empty_access_request_dynamodb( task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, ) # Here the results should be empty as no data will be located for that identity - results = pr.get_results() + results = pr.get_raw_access_results() pr.delete(db=db) assert results == {} @pytest.mark.integration_external @pytest.mark.integration_dynamodb +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", "use_dsr_2_0"], +) def test_create_and_process_access_request_dynamodb( dynamodb_resources, db, cache, policy, run_privacy_request_task, + dsr_version, + request, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = dynamodb_resources["email"] customer_name = dynamodb_resources["name"] customer_id = dynamodb_resources["customer_id"] @@ -2665,14 +3010,10 @@ def test_create_and_process_access_request_dynamodb( data, task_timeout=PRIVACY_REQUEST_TASK_TIMEOUT_EXTERNAL, ) - results = pr.get_results() - customer_table_key = ( - f"EN_{pr.id}__access_request__dynamodb_example_test_dataset:customer" - ) - address_table_key = ( - f"EN_{pr.id}__access_request__dynamodb_example_test_dataset:address" - ) - login_table_key = f"EN_{pr.id}__access_request__dynamodb_example_test_dataset:login" + results = pr.get_raw_access_results() + customer_table_key = f"dynamodb_example_test_dataset:customer" + address_table_key = f"dynamodb_example_test_dataset:address" + login_table_key = f"dynamodb_example_test_dataset:login" assert len(results[customer_table_key]) == 1 assert len(results[address_table_key]) == 1 assert len(results[login_table_key]) == 2 @@ -2687,6 +3028,10 @@ def test_create_and_process_access_request_dynamodb( @pytest.mark.integration_external @pytest.mark.integration_dynamodb +@pytest.mark.parametrize( + "dsr_version", + ["use_dsr_3_0", 
"use_dsr_2_0"], +) def test_create_and_process_erasure_request_dynamodb( dynamodb_example_test_dataset_config, dynamodb_resources, @@ -2694,8 +3039,12 @@ def test_create_and_process_erasure_request_dynamodb( db, cache, erasure_policy, + dsr_version, + request, run_privacy_request_task, ): + request.getfixturevalue(dsr_version) # REQUIRED to test both DSR 3.0 and 2.0 + customer_email = dynamodb_resources["email"] dynamodb_client = dynamodb_resources["client"] customer_id = dynamodb_resources["customer_id"] diff --git a/tests/ops/service/privacy_request/test_request_service.py b/tests/ops/service/privacy_request/test_request_service.py index 8e889fc8de..1832957165 100644 --- a/tests/ops/service/privacy_request/test_request_service.py +++ b/tests/ops/service/privacy_request/test_request_service.py @@ -1,15 +1,26 @@ +import json +import time from datetime import datetime +from unittest import mock import pytest from httpx import HTTPStatusError from fides.api.cryptography.cryptographic_util import str_to_b64_str from fides.api.db.seed import create_or_update_parent_user -from fides.api.models.privacy_request import PrivacyRequest, PrivacyRequestStatus +from fides.api.models.privacy_request import ( + ExecutionLogStatus, + PrivacyRequest, + PrivacyRequestStatus, +) +from fides.api.schemas.policy import ActionType from fides.api.service.privacy_request.request_service import ( build_required_privacy_request_kwargs, + poll_for_exited_privacy_request_tasks, poll_server_for_completion, + remove_saved_dsr_data, ) +from fides.api.util.cache import CustomJSONEncoder from fides.common.api.v1.urn_registry import LOGIN, V1_URL_PREFIX from fides.config import CONFIG @@ -113,6 +124,199 @@ async def test_poll_server_for_completion_non_200(async_api_client, db, policy): ) +class TestPollForExitedPrivacyRequests: + def test_no_request_tasks(self, db, privacy_request): + errored_prs = poll_for_exited_privacy_request_tasks.delay().get() + assert errored_prs == set() + + db.refresh(privacy_request) + assert privacy_request.status == PrivacyRequestStatus.in_processing + + @pytest.mark.usefixtures("request_task") + def test_request_tasks_still_in_progress(self, db, privacy_request): + errored_prs = poll_for_exited_privacy_request_tasks.delay().get() + assert errored_prs == set() + + db.refresh(privacy_request) + assert privacy_request.status == PrivacyRequestStatus.in_processing + + def test_request_tasks_all_exited_and_some_errored( + self, db, privacy_request, request_task + ): + # Put all tasks in an exited state - completed, errored, or skipped + root_task = privacy_request.get_root_task_by_action(ActionType.access) + assert root_task.status == ExecutionLogStatus.complete + request_task.update_status(db, ExecutionLogStatus.skipped) + terminator_task = privacy_request.get_terminate_task_by_action( + ActionType.access + ) + terminator_task.update_status(db, ExecutionLogStatus.error) + + errored_prs = poll_for_exited_privacy_request_tasks.delay().get() + assert errored_prs == {privacy_request.id} + + db.refresh(privacy_request) + assert privacy_request.status == PrivacyRequestStatus.error + + def test_request_tasks_all_exited_none_errored( + self, db, privacy_request, request_task + ): + # Put all tasks in an exited state - but none are errored. 
+ # This task does not flip the status of the overall privacy request in that case + root_task = privacy_request.get_root_task_by_action(ActionType.access) + assert root_task.status == ExecutionLogStatus.complete + request_task.update_status(db, ExecutionLogStatus.skipped) + terminator_task = privacy_request.get_terminate_task_by_action( + ActionType.access + ) + terminator_task.update_status(db, ExecutionLogStatus.complete) + + errored_prs = poll_for_exited_privacy_request_tasks.delay().get() + assert errored_prs == set() + + db.refresh(privacy_request) + assert privacy_request.status == PrivacyRequestStatus.in_processing + + def test_access_tasks_errored_erasure_tasks_pending( + self, db, privacy_request, request_task, erasure_request_task + ): + """Request tasks of different action types are considered separately. + If access tasks have errored but erasure tasks are pending, the access step itself + causes the privacy request as a whole to be marked as error. + """ + # Erasure tasks still pending - these were created at the same time as the + # access tasks but can't run until the access section is finished + assert erasure_request_task.status == ExecutionLogStatus.pending + + # Access tasks have errored + root_task = privacy_request.get_root_task_by_action(ActionType.access) + assert root_task.status == ExecutionLogStatus.complete + request_task.update_status(db, ExecutionLogStatus.error) + terminator_task = privacy_request.get_terminate_task_by_action( + ActionType.access + ) + terminator_task.update_status(db, ExecutionLogStatus.error) + + errored_prs = poll_for_exited_privacy_request_tasks.delay().get() + assert errored_prs == {privacy_request.id} + + db.refresh(privacy_request) + assert privacy_request.status == PrivacyRequestStatus.error + + @pytest.mark.usefixtures("request_task") + def test_access_tasks_complete_erasure_tasks_errored( + self, db, privacy_request, erasure_request_task + ): + """Tasks of different action types are considered separately. 
If all access tasks + have completed, but erasure tasks have an error, the entire privacy request will be marked as error + """ + for rq in privacy_request.request_tasks: + rq.update_status(db, ExecutionLogStatus.complete) + + erasure_request_task.update_status(db, ExecutionLogStatus.error) + + errored_prs = poll_for_exited_privacy_request_tasks.delay().get() + assert errored_prs == {privacy_request.id} + + db.refresh(privacy_request) + assert privacy_request.status == PrivacyRequestStatus.error + + +@pytest.fixture(scope="function") +def very_short_request_task_expiration(): + original_value: float = CONFIG.execution.request_task_ttl + CONFIG.execution.request_task_ttl = ( + 0.01 # Set request tasks to expire very quickly for testing purposes + ) + yield CONFIG + CONFIG.execution.request_task_ttl = original_value + + +@pytest.fixture(scope="function") +def very_short_redis_cache_expiration(): + original_value: float = CONFIG.redis.default_ttl_seconds + CONFIG.redis.default_ttl_seconds = ( + 0.01 # Set redis cache to expire very quickly for testing purposes + ) + yield CONFIG + CONFIG.redis.default_ttl_seconds = original_value + + +class TestRemoveSavedCustomerData: + @pytest.mark.usefixtures( + "very_short_redis_cache_expiration", "very_short_request_task_expiration" + ) + def test_no_request_tasks(self, db, privacy_request): + assert not privacy_request.request_tasks.count() + time.sleep(1) + + # Mainly asserting this runs without error + remove_saved_dsr_data.delay().get() + + db.refresh(privacy_request) + assert not privacy_request.request_tasks.count() + + @pytest.mark.usefixtures( + "very_short_redis_cache_expiration", + "very_short_request_task_expiration", + "request_task", + ) + def test_privacy_request_incomplete(self, db, privacy_request): + """Incomplete Privacy Requests are not cleaned up""" + assert privacy_request.status == PrivacyRequestStatus.in_processing + privacy_request.save(db) + + privacy_request.filtered_final_upload = json.dumps( + {"rule_key": {"test_dataset:test_collection": [{"id": 1, "name": "Jane"}]}}, + cls=CustomJSONEncoder, + ) + privacy_request.access_result_urls = {"access_result_urls": ["www.example.com"]} + privacy_request.save(db) + + assert privacy_request.request_tasks.count() + time.sleep(1) + + remove_saved_dsr_data.delay().get() + + db.refresh(privacy_request) + assert privacy_request.filtered_final_upload is not None + assert privacy_request.access_result_urls is not None + assert privacy_request.request_tasks.count() + + @pytest.mark.usefixtures( + "very_short_redis_cache_expiration", + "very_short_request_task_expiration", + "request_task", + ) + def test_customer_data_removed_from_old_request_tasks_and_privacy_requests( + self, db, privacy_request, loguru_caplog + ): + privacy_request.status = PrivacyRequestStatus.complete + privacy_request.save(db) + + privacy_request.filtered_final_upload = json.dumps( + {"rule_key": {"test_dataset:test_collection": [{"id": 1, "name": "Jane"}]}}, + cls=CustomJSONEncoder, + ) + privacy_request.access_result_urls = {"access_result_urls": ["www.example.com"]} + privacy_request.save(db) + + assert privacy_request.request_tasks.count() + time.sleep(1) + + remove_saved_dsr_data.delay().get() + + db.refresh(privacy_request) + assert privacy_request.filtered_final_upload is None + assert privacy_request.access_result_urls is None + assert not privacy_request.request_tasks.count() + + assert ( + "Deleted 3 expired request tasks via DSR Data Removal Task." 
+ in loguru_caplog.text + ) + + class TestBuildPrivacyRequestRequiredKwargs: def test_build_required_privacy_request_kwargs_authenticated(self): resp = build_required_privacy_request_kwargs( diff --git a/tests/ops/task/test_create_request_tasks.py b/tests/ops/task/test_create_request_tasks.py new file mode 100644 index 0000000000..be733b6f28 --- /dev/null +++ b/tests/ops/task/test_create_request_tasks.py @@ -0,0 +1,1660 @@ +from datetime import datetime +from unittest import mock +from unittest.mock import MagicMock + +import pytest +from fideslang import Dataset + +from fides.api.common_exceptions import TraversalError +from fides.api.graph.config import ROOT_COLLECTION_ADDRESS, TERMINATOR_ADDRESS +from fides.api.graph.graph import DatasetGraph +from fides.api.graph.traversal import Traversal, TraversalNode +from fides.api.models.datasetconfig import convert_dataset_to_graph +from fides.api.models.privacy_request import ( + ExecutionLogStatus, + PrivacyRequestStatus, + RequestTask, +) +from fides.api.schemas.policy import ActionType +from fides.api.task.create_request_tasks import ( + collect_tasks_fn, + get_existing_ready_tasks, + persist_initial_erasure_request_tasks, + persist_new_access_request_tasks, + persist_new_consent_request_tasks, + run_access_request, + run_consent_request, + run_erasure_request, + update_erasure_tasks_with_access_data, +) +from fides.api.task.execute_request_tasks import run_access_node +from fides.api.task.graph_task import build_consent_dataset_graph +from fides.config import CONFIG +from tests.conftest import wait_for_tasks_to_complete +from tests.ops.task.traversal_data import combined_mongo_postgresql_graph + +from ..graph.graph_test_util import erasure_policy, field + +payment_card_serialized_collection = { + "name": "payment_card", + "after": [], + "fields": [ + { + "name": "billing_address_id", + "length": None, + "identity": None, + "is_array": False, + "read_only": None, + "references": [["postgres_example_test_dataset:address:id", "to"]], + "primary_key": False, + "data_categories": ["system.operations"], + "data_type_converter": "None", + "return_all_elements": None, + }, + { + "name": "ccn", + "length": None, + "identity": None, + "is_array": False, + "read_only": None, + "references": [], + "primary_key": False, + "data_categories": ["user.financial.bank_account"], + "data_type_converter": "None", + "return_all_elements": None, + }, + { + "name": "code", + "length": None, + "identity": None, + "is_array": False, + "read_only": None, + "references": [], + "primary_key": False, + "data_categories": ["user.financial"], + "data_type_converter": "None", + "return_all_elements": None, + }, + { + "name": "customer_id", + "length": None, + "identity": None, + "is_array": False, + "read_only": None, + "references": [["postgres_example_test_dataset:customer:id", "from"]], + "primary_key": False, + "data_categories": ["user.unique_id"], + "data_type_converter": "None", + "return_all_elements": None, + }, + { + "name": "id", + "length": None, + "identity": None, + "is_array": False, + "read_only": None, + "references": [], + "primary_key": True, + "data_categories": ["system.operations"], + "data_type_converter": "None", + "return_all_elements": None, + }, + { + "name": "name", + "length": None, + "identity": None, + "is_array": False, + "read_only": None, + "references": [], + "primary_key": False, + "data_categories": ["user.financial"], + "data_type_converter": "None", + "return_all_elements": None, + }, + { + "name": "preferred", + "length": None, + 
"identity": None, + "is_array": False, + "read_only": None, + "references": [], + "primary_key": False, + "data_categories": ["user"], + "data_type_converter": "None", + "return_all_elements": None, + }, + ], + "erase_after": [], + "grouped_inputs": [], + "skip_processing": False, +} + +payment_card_serialized_traversal_details = { + "input_keys": ["postgres_example_test_dataset:customer"], + "incoming_edges": [ + [ + "postgres_example_test_dataset:customer:id", + "postgres_example_test_dataset:payment_card:customer_id", + ] + ], + "outgoing_edges": [ + [ + "postgres_example_test_dataset:payment_card:billing_address_id", + "postgres_example_test_dataset:address:id", + ] + ], + "dataset_connection_key": "my_postgres_db_1", +} + + +class TestPersistAccessRequestTasks: + def test_persist_access_tasks(self, db, privacy_request, postgres_dataset_graph): + """Test the RequestTasks that are generated for an access request""" + identity = {"email": "customer-1@example.com"} + traversal: Traversal = Traversal(postgres_dataset_graph, identity) + traversal_nodes = {} + end_nodes = traversal.traverse(traversal_nodes, collect_tasks_fn) + + ready_tasks = persist_new_access_request_tasks( + db, + privacy_request, + traversal, + traversal_nodes, + end_nodes, + postgres_dataset_graph, + ) + assert len(ready_tasks) == 1 + + root_task = ready_tasks[0] + # The Root Task is the only one ready to be queued - assert key details + assert root_task.privacy_request_id == privacy_request.id + assert root_task.action_type == ActionType.access + assert root_task.collection_address == "__ROOT__:__ROOT__" + assert root_task.dataset_name == "__ROOT__" + assert root_task.collection_name == "__ROOT__" + # We just create the root task in the completed state automatically + assert root_task.status == ExecutionLogStatus.complete + assert root_task.upstream_tasks == [] + # These are the downstream data dependencies + assert root_task.downstream_tasks == [ + "postgres_example_test_dataset:customer", + "postgres_example_test_dataset:employee", + "postgres_example_test_dataset:report", + "postgres_example_test_dataset:service_request", + "postgres_example_test_dataset:visit", + ] + # All nodes can be reached by the root node + assert ( + len(root_task.all_descendant_tasks) + == 12 + == privacy_request.access_tasks.count() - 1 + ) + # Identity data is saved as encrypted access data - + assert root_task.access_data == '[{"email": "customer-1@example.com"}]' + assert root_task.get_decoded_access_data() == [ + {"email": "customer-1@example.com"} + ] + # ARTIFICIAL NODES don't have collections or traversal details + assert root_task.collection is None + assert root_task.traversal_details == {} + assert root_task.is_root_task + assert not root_task.is_terminator_task + + # Assert key details on terminator task + terminator_task = privacy_request.get_terminate_task_by_action( + ActionType.access + ) + + assert terminator_task.action_type == ActionType.access + assert terminator_task.collection_name == "__TERMINATE__" + assert terminator_task.dataset_name == "__TERMINATE__" + assert terminator_task.status == ExecutionLogStatus.pending + assert terminator_task.upstream_tasks == [ + "postgres_example_test_dataset:address", + "postgres_example_test_dataset:login", + "postgres_example_test_dataset:product", + "postgres_example_test_dataset:report", + "postgres_example_test_dataset:service_request", + "postgres_example_test_dataset:visit", + ] + + assert terminator_task.downstream_tasks == [] + assert terminator_task.all_descendant_tasks == 
[] + assert terminator_task.access_data == "[]" + # ARTIFICIAL NODES don't have collections or traversal details + assert terminator_task.collection is None + assert terminator_task.traversal_details == {} + assert not terminator_task.is_root_task + assert terminator_task.is_terminator_task + + # Assert key details on payment card task + payment_card_task = privacy_request.access_tasks.filter( + RequestTask.collection_address + == "postgres_example_test_dataset:payment_card" + ).first() + assert payment_card_task.action_type == ActionType.access + assert payment_card_task.collection_name == "payment_card" + assert payment_card_task.dataset_name == "postgres_example_test_dataset" + assert payment_card_task.status == ExecutionLogStatus.pending + assert payment_card_task.upstream_tasks == [ + "postgres_example_test_dataset:customer" + ] + assert payment_card_task.downstream_tasks == [ + "postgres_example_test_dataset:address" + ] + assert payment_card_task.all_descendant_tasks == [ + "__TERMINATE__:__TERMINATE__", + "postgres_example_test_dataset:address", + ] + assert payment_card_task.access_data == "[]" + assert payment_card_task.collection == payment_card_serialized_collection + assert ( + payment_card_task.traversal_details + == payment_card_serialized_traversal_details + ) + assert not payment_card_task.is_root_task + assert not payment_card_task.is_terminator_task + + def test_persist_access_tasks_with_object_fields_in_collection( + self, db, privacy_request, postgres_and_mongo_dataset_graph + ): + identity = {"email": "customer-1@example.com"} + + traversal: Traversal = Traversal(postgres_and_mongo_dataset_graph, identity) + traversal_nodes = {} + end_nodes = traversal.traverse(traversal_nodes, collect_tasks_fn) + + ready_tasks = persist_new_access_request_tasks( + db, + privacy_request, + traversal, + traversal_nodes, + end_nodes, + postgres_and_mongo_dataset_graph, + ) + assert len(ready_tasks) == 1 + + customer_profile = privacy_request.access_tasks.filter( + RequestTask.collection_address == "mongo_test:internal_customer_profile" + ).first() + + # Object fields serialized correctly + assert customer_profile.collection == { + "name": "internal_customer_profile", + "after": [], + "fields": [ + { + "name": "_id", + "length": None, + "identity": None, + "is_array": False, + "read_only": None, + "references": [], + "primary_key": True, + "data_categories": ["system.operations"], + "data_type_converter": "object_id", + "return_all_elements": None, + }, + { + "name": "customer_identifiers", + "fields": { + "internal_id": { + "name": "internal_id", + "length": None, + "identity": None, + "is_array": False, + "read_only": None, + "references": [ + [ + "mongo_test:customer_feedback:customer_information.internal_customer_id", + "from", + ] + ], + "primary_key": False, + "data_categories": None, + "data_type_converter": "string", + "return_all_elements": None, + }, + "derived_phone": { + "name": "derived_phone", + "length": None, + "identity": "phone_number", + "is_array": True, + "read_only": None, + "references": [], + "primary_key": False, + "data_categories": ["user"], + "data_type_converter": "string", + "return_all_elements": True, + }, + "derived_emails": { + "name": "derived_emails", + "length": None, + "identity": "email", + "is_array": True, + "read_only": None, + "references": [], + "primary_key": False, + "data_categories": ["user"], + "data_type_converter": "string", + "return_all_elements": None, + }, + }, + "length": None, + "identity": None, + "is_array": False, + "read_only": 
None, + "references": [], + "primary_key": False, + "data_categories": None, + "data_type_converter": "object", + "return_all_elements": None, + }, + { + "name": "derived_interests", + "length": None, + "identity": None, + "is_array": True, + "read_only": None, + "references": [], + "primary_key": False, + "data_categories": ["user"], + "data_type_converter": "string", + "return_all_elements": None, + }, + ], + "erase_after": [], + "grouped_inputs": [], + "skip_processing": False, + } + + def test_no_collections(self, db, privacy_request): + identity = {"email": "customer-1@example.com"} + + traversal: Traversal = Traversal(DatasetGraph(), identity) + traversal_nodes = {} + end_nodes = traversal.traverse(traversal_nodes, collect_tasks_fn) + + ready_tasks = persist_new_access_request_tasks( + db, + privacy_request, + traversal, + traversal_nodes, + end_nodes, + DatasetGraph(), + ) + + assert len(ready_tasks) == 1 + db.refresh(privacy_request) + assert len(privacy_request.access_tasks.all()) == 2 + assert privacy_request.access_tasks[0].is_root_task + assert privacy_request.access_tasks[0].upstream_tasks == [] + assert privacy_request.access_tasks[0].downstream_tasks == [ + TERMINATOR_ADDRESS.value + ] + + assert privacy_request.access_tasks[1].is_terminator_task + assert privacy_request.access_tasks[1].upstream_tasks == [ + ROOT_COLLECTION_ADDRESS.value + ] + assert privacy_request.access_tasks[1].downstream_tasks == [] + + @mock.patch( + "fides.api.task.create_request_tasks.queue_request_task", + ) + def test_run_access_request_no_request_tasks_existing( + self, run_access_node_mock, db, privacy_request, policy + ): + """Request tasks created by run_access_request and the root task is queued""" + ready = run_access_request( + privacy_request, + policy, + DatasetGraph(), + [], + {"email": "customer-4@example.com"}, + db, + privacy_request_proceed=False, + ) + + assert len(ready) == 1 + root_task = ready[0] + assert root_task.is_root_task + + assert run_access_node_mock.called + run_access_node_mock.assert_called_with(root_task, False) + + @mock.patch( + "fides.api.task.create_request_tasks.queue_request_task", + ) + def test_reprocess_access_request_with_existing_request_tasks( + self, run_access_node_mock, request_task, db, privacy_request, policy + ): + assert privacy_request.access_tasks.count() == 3 + + ready = run_access_request( + privacy_request, + policy, + DatasetGraph(), + [], + {"email": "customer-4@example.com"}, + db, + privacy_request_proceed=False, + ) + + assert len(ready) == 1 + ready_task = ready[0] + assert ready_task == request_task + assert not ready_task.is_root_task + assert ready_task.status == ExecutionLogStatus.pending + + assert run_access_node_mock.called + run_access_node_mock.assert_called_with(request_task, False) + + +class TestPersistErasureRequestTasks: + def test_persist_initial_erasure_request_tasks( + self, db, privacy_request, postgres_dataset_graph + ): + """Test the RequestTasks that are generated for an erasure graph + These are generated at the same time as the access graph, but are not runnable + until the access graph is completed in full + """ + identity = {"email": "customer-1@example.com"} + traversal: Traversal = Traversal(postgres_dataset_graph, identity) + + traversal_nodes = {} + _ = traversal.traverse(traversal_nodes, collect_tasks_fn) + # Because the access graph completes in full first, getting all the data the erasure + # graph needs to build masking requests, the erasure graph can be run entirely + # in parallel. 
So the end nodes are all of the nodes. + erasure_end_nodes = list(postgres_dataset_graph.nodes.keys()) + + ready_tasks = persist_initial_erasure_request_tasks( + db, + privacy_request, + traversal_nodes, + erasure_end_nodes, + postgres_dataset_graph, + ) + assert ready_tasks == [] + + assert privacy_request.erasure_tasks.count() == 13 + + root_task = privacy_request.get_root_task_by_action(ActionType.erasure) + assert root_task.action_type == ActionType.erasure + assert root_task.privacy_request_id == privacy_request.id + assert root_task.collection_address == "__ROOT__:__ROOT__" + assert root_task.dataset_name == "__ROOT__" + assert root_task.collection_name == "__ROOT__" + assert root_task.status == ExecutionLogStatus.complete + assert root_task.upstream_tasks == [] + # Every node other than the terminate node is downstream of the root node + assert root_task.downstream_tasks == [ + "postgres_example_test_dataset:address", + "postgres_example_test_dataset:customer", + "postgres_example_test_dataset:employee", + "postgres_example_test_dataset:login", + "postgres_example_test_dataset:order_item", + "postgres_example_test_dataset:orders", + "postgres_example_test_dataset:payment_card", + "postgres_example_test_dataset:product", + "postgres_example_test_dataset:report", + "postgres_example_test_dataset:service_request", + "postgres_example_test_dataset:visit", + ] + # Every node that can be reached by the root node + assert ( + len(root_task.all_descendant_tasks) + == 12 + == privacy_request.erasure_tasks.count() - 1 + ) + assert root_task.access_data is None + assert root_task.data_for_erasures is None + assert root_task.get_decoded_access_data() == [] + # ARTIFICIAL NODES don't have collections or traversal details + assert root_task.collection is None + assert root_task.traversal_details == {} + assert root_task.is_root_task + assert not root_task.is_terminator_task + + # Assert key details on terminator task + terminator_task = privacy_request.get_terminate_task_by_action( + ActionType.erasure + ) + assert terminator_task.action_type == ActionType.erasure + assert terminator_task.collection_name == "__TERMINATE__" + assert terminator_task.dataset_name == "__TERMINATE__" + assert terminator_task.status == ExecutionLogStatus.pending + # Every node but the root node has the terminator task downstream of it + assert terminator_task.upstream_tasks == root_task.downstream_tasks + assert terminator_task.downstream_tasks == [] + assert terminator_task.all_descendant_tasks == [] + assert terminator_task.access_data is None + assert terminator_task.data_for_erasures is None + # ARTIFICIAL NODES don't have collections or traversal details + assert terminator_task.collection is None + assert terminator_task.traversal_details == {} + assert not terminator_task.is_root_task + assert terminator_task.is_terminator_task + + # Assert key details on payment card task + payment_card_task = privacy_request.erasure_tasks.filter( + RequestTask.collection_address + == "postgres_example_test_dataset:payment_card" + ).first() + assert payment_card_task.action_type == ActionType.erasure + assert payment_card_task.collection_name == "payment_card" + assert payment_card_task.dataset_name == "postgres_example_test_dataset" + assert payment_card_task.status == ExecutionLogStatus.pending + assert payment_card_task.upstream_tasks == ["__ROOT__:__ROOT__"] + assert payment_card_task.downstream_tasks == [ + "__TERMINATE__:__TERMINATE__", + ] + assert payment_card_task.all_descendant_tasks == [ + 
"__TERMINATE__:__TERMINATE__", + ] + assert payment_card_task.access_data is None + assert payment_card_task.data_for_erasures is None + # Even though the downstream task is just the terminate node and the upstream + # task is just the root node, it's upstream and downstream edges are still + # based on data dependencies + assert payment_card_task.collection == payment_card_serialized_collection + assert ( + payment_card_task.traversal_details + == payment_card_serialized_traversal_details + ) + + assert not payment_card_task.is_root_task + assert not payment_card_task.is_terminator_task + + @pytest.mark.timeout(5) + @pytest.mark.integration + @pytest.mark.integration_postgres + def test_update_erasure_tasks_with_access_data( + self, db, privacy_request, example_datasets, integration_postgres_config + ): + """Test that erasure tasks are updated with the corresponding erasure data collected + from the access task""" + dataset = Dataset(**example_datasets[0]) + graph = convert_dataset_to_graph(dataset, integration_postgres_config.key) + dataset_graph = DatasetGraph(*[graph]) + + identity = {"email": "customer-1@example.com"} + traversal: Traversal = Traversal(dataset_graph, identity) + + traversal_nodes = {} + access_end_nodes = traversal.traverse(traversal_nodes, collect_tasks_fn) + erasure_end_nodes = list(dataset_graph.nodes.keys()) + + ready_tasks = persist_new_access_request_tasks( + db, + privacy_request, + traversal, + traversal_nodes, + access_end_nodes, + dataset_graph, + ) + + persist_initial_erasure_request_tasks( + db, + privacy_request, + traversal_nodes, + erasure_end_nodes, + dataset_graph, + ) + + run_access_node.delay( + privacy_request.id, ready_tasks[0].id, privacy_request_proceed=False + ) + wait_for_tasks_to_complete(db, privacy_request, ActionType.access) + + update_erasure_tasks_with_access_data(db, privacy_request) + payment_card_task = privacy_request.erasure_tasks.filter( + RequestTask.collection_address + == "postgres_example_test_dataset:payment_card" + ).first() + + # access data collected for masking was added to this erasure node of the same address + assert ( + payment_card_task.data_for_erasures + == '[{"billing_address_id": 1, "ccn": 123456789, "code": 321, "customer_id": 1, "id": "pay_aaa-aaa", "name": "Example Card 1", "preferred": true}]' + ) + assert payment_card_task.get_decoded_data_for_erasures() == [ + { + "billing_address_id": 1, + "ccn": 123456789, + "code": 321, + "customer_id": 1, + "id": "pay_aaa-aaa", + "name": "Example Card 1", + "preferred": True, + } + ] + assert payment_card_task.status == ExecutionLogStatus.pending + + address_task = privacy_request.erasure_tasks.filter( + RequestTask.collection_address == "postgres_example_test_dataset:address" + ).first() + + assert address_task.traversal_details["input_keys"] == [ + "postgres_example_test_dataset:customer", + "postgres_example_test_dataset:employee", + "postgres_example_test_dataset:orders", + "postgres_example_test_dataset:payment_card", + ] + + @pytest.mark.timeout(5) + @pytest.mark.integration + @pytest.mark.integration_postgres + @pytest.mark.integration_mongodb + def test_update_erasure_tasks_with_placeholder_access_data( + self, + db, + privacy_request, + mongo_inserts, + postgres_inserts, + integration_postgres_config, + integration_mongodb_config, + ): + """Test that erasure tasks are updated with the corresponding erasure data collected + from the access task""" + policy = erasure_policy(db, "user.name", "user.contact") + privacy_request.policy_id = policy.id + 
privacy_request.save(db) + + mongo_dataset, postgres_dataset = combined_mongo_postgresql_graph( + integration_postgres_config, integration_mongodb_config + ) + + graph = DatasetGraph(mongo_dataset, postgres_dataset) + + identity = {"email": mongo_inserts["customer"][0]["email"]} + traversal: Traversal = Traversal(graph, identity) + + traversal_nodes = {} + access_end_nodes = traversal.traverse(traversal_nodes, collect_tasks_fn) + erasure_end_nodes = list(graph.nodes.keys()) + + ready_tasks = persist_new_access_request_tasks( + db, + privacy_request, + traversal, + traversal_nodes, + access_end_nodes, + graph, + ) + + persist_initial_erasure_request_tasks( + db, + privacy_request, + traversal_nodes, + erasure_end_nodes, + graph, + ) + + run_access_node.delay( + privacy_request.id, ready_tasks[0].id, privacy_request_proceed=False + ) + wait_for_tasks_to_complete(db, privacy_request, ActionType.access) + + update_erasure_tasks_with_access_data(db, privacy_request) + + conversations_task = privacy_request.erasure_tasks.filter( + RequestTask.collection_address == "mongo_test:conversations" + ).first() + # The erasure format may save placeholder array elements to denote which elements should not be masked while preserving the original index + assert conversations_task.get_decoded_data_for_erasures()[1]["thread"] == [ + { + "comment": "com_0013", + "message": "should we text Grace when we land or should we just surprise her?", + "chat_name": "John C", + "ccn": "123456789", + }, + "FIDESOPS_DO_NOT_MASK", + { + "comment": "com_0015", + "message": "Aw but she loves surprises.", + "chat_name": "John C", + "ccn": "123456789", + }, + "FIDESOPS_DO_NOT_MASK", + ] + + def test_erase_after_upstream_and_downstream_tasks( + self, + db, + privacy_request, + saas_erasure_order_config, + saas_erasure_order_connection_config, + saas_erasure_order_dataset_config, + ): + saas_erasure_order_connection_config.update( + db, data={"saas_config": saas_erasure_order_config} + ) + merged_graph = saas_erasure_order_dataset_config.get_graph() + graph = DatasetGraph(merged_graph) + + identity = {"email": "customer-1@example.com"} + traversal: Traversal = Traversal(graph, identity) + + traversal_nodes = {} + _ = traversal.traverse(traversal_nodes, collect_tasks_fn) + erasure_end_nodes = list(graph.nodes.keys()) + + persist_initial_erasure_request_tasks( + db, + privacy_request, + traversal_nodes, + erasure_end_nodes, + graph, + ) + + orders_task = privacy_request.erasure_tasks.filter( + RequestTask.collection_address == "saas_erasure_order_instance:orders" + ).first() + # These are tasks that are specifically marked as "erase_after" + assert orders_task.upstream_tasks == [ + "__ROOT__:__ROOT__", + "saas_erasure_order_instance:orders_to_refunds", + "saas_erasure_order_instance:refunds_to_orders", + ] + assert orders_task.downstream_tasks == ["saas_erasure_order_instance:labels"] + # Data dependencies are still from the root node + assert orders_task.traversal_details["input_keys"] == ["__ROOT__:__ROOT__"] + serialized_collection = orders_task.collection + assert serialized_collection["name"] == "orders" + assert len(serialized_collection["fields"]) == 2 + assert serialized_collection["fields"] == [ + { + "name": "id", + "length": None, + "identity": None, + "is_array": False, + "read_only": None, + "references": [], + "primary_key": True, + "data_categories": ["system.operations"], + "data_type_converter": "integer", + "return_all_elements": None, + }, + { + "name": "email", + "length": None, + "identity": "email", + "is_array": False, + 
"read_only": None, + "references": [], + "primary_key": False, + "data_categories": None, + "data_type_converter": "None", + "return_all_elements": None, + }, + ] + assert not serialized_collection["skip_processing"] + assert serialized_collection["grouped_inputs"] == [] + assert set(serialized_collection["erase_after"]) == { + "saas_erasure_order_instance:orders_to_refunds", + "saas_erasure_order_instance:refunds_to_orders", + "__ROOT__:__ROOT__", + } + + refunds_task = privacy_request.erasure_tasks.filter( + RequestTask.collection_address == "saas_erasure_order_instance:refunds" + ).first() + # These are tasks that are specifically marked as "erase_after" + assert refunds_task.upstream_tasks == [ + "__ROOT__:__ROOT__", + "saas_erasure_order_instance:orders_to_refunds", + "saas_erasure_order_instance:refunds_to_orders", + ] + assert refunds_task.downstream_tasks == ["saas_erasure_order_instance:labels"] + + labels_task = privacy_request.erasure_tasks.filter( + RequestTask.collection_address == "saas_erasure_order_instance:labels" + ).first() + # Data dependencies are still from the root node + assert labels_task.traversal_details["input_keys"] == ["__ROOT__:__ROOT__"] + # These are tasks that are specifically marked as "erase_after" + assert labels_task.upstream_tasks == [ + "__ROOT__:__ROOT__", + "saas_erasure_order_instance:orders", + "saas_erasure_order_instance:refunds", + ] + assert labels_task.downstream_tasks == ["__TERMINATE__:__TERMINATE__"] + + orders_to_refunds = privacy_request.erasure_tasks.filter( + RequestTask.collection_address + == "saas_erasure_order_instance:orders_to_refunds" + ).first() + # Data dependencies are from orders node though + assert orders_to_refunds.traversal_details["input_keys"] == [ + "saas_erasure_order_instance:orders" + ] + assert orders_to_refunds.upstream_tasks == ["__ROOT__:__ROOT__"] + assert orders_to_refunds.downstream_tasks == [ + "saas_erasure_order_instance:orders", + "saas_erasure_order_instance:refunds", + ] + + refunds_to_order = privacy_request.erasure_tasks.filter( + RequestTask.collection_address + == "saas_erasure_order_instance:refunds_to_orders" + ).first() + # Data dependencies are refunds node though + assert refunds_to_order.traversal_details["input_keys"] == [ + "saas_erasure_order_instance:refunds" + ] + assert refunds_to_order.upstream_tasks == ["__ROOT__:__ROOT__"] + assert refunds_to_order.downstream_tasks == [ + "saas_erasure_order_instance:orders", + "saas_erasure_order_instance:refunds", + ] + + products = privacy_request.erasure_tasks.filter( + RequestTask.collection_address == "saas_erasure_order_instance:products" + ).first() + # Data dependencies are still from the root node + assert products.traversal_details["input_keys"] == ["__ROOT__:__ROOT__"] + # These are tasks that are specifically marked as "erase_after" + assert products.upstream_tasks == ["__ROOT__:__ROOT__"] + assert products.downstream_tasks == ["__TERMINATE__:__TERMINATE__"] + + def test_erase_after_incorrectly_creates_cycle( + self, + db, + privacy_request, + saas_erasure_order_config, + saas_erasure_order_connection_config, + saas_erasure_order_dataset_config, + ): + dataset_name = saas_erasure_order_connection_config.get_saas_config().fides_key + saas_erasure_order_config["endpoints"][0]["erase_after"].append( + f"{dataset_name}.labels" + ) + saas_erasure_order_connection_config.update( + db, data={"saas_config": saas_erasure_order_config} + ) + merged_graph = saas_erasure_order_dataset_config.get_graph() + graph = DatasetGraph(merged_graph) + + 
identity = {"email": "customer-1@example.com"} + traversal: Traversal = Traversal(graph, identity) + + traversal_nodes = {} + _ = traversal.traverse(traversal_nodes, collect_tasks_fn) + erasure_end_nodes = list(graph.nodes.keys()) + + with pytest.raises(TraversalError): + persist_initial_erasure_request_tasks( + db, + privacy_request, + traversal_nodes, + erasure_end_nodes, + graph, + ) + + def test_no_collections(self, db, privacy_request): + identity = {"email": "customer-1@example.com"} + + graph = DatasetGraph() + + traversal: Traversal = Traversal(graph, identity) + + traversal_nodes = {} + _ = traversal.traverse(traversal_nodes, collect_tasks_fn) + erasure_end_nodes = list(graph.nodes.keys()) + ready_tasks = persist_initial_erasure_request_tasks( + db, + privacy_request, + traversal_nodes, + erasure_end_nodes, + graph, + ) + + assert len(ready_tasks) == 0 + db.refresh(privacy_request) + + assert len(privacy_request.erasure_tasks.all()) == 2 + assert privacy_request.erasure_tasks[0].is_root_task + assert privacy_request.erasure_tasks[0].upstream_tasks == [] + assert privacy_request.erasure_tasks[0].downstream_tasks == [ + TERMINATOR_ADDRESS.value + ] + + assert privacy_request.erasure_tasks[1].is_terminator_task + assert privacy_request.erasure_tasks[1].upstream_tasks == [ + ROOT_COLLECTION_ADDRESS.value + ] + assert privacy_request.erasure_tasks[1].downstream_tasks == [] + + @mock.patch( + "fides.api.task.create_request_tasks.update_erasure_tasks_with_access_data", + ) + @mock.patch( + "fides.api.task.create_request_tasks.queue_request_task", + ) + def test_run_erasure_request_with_existing_request_tasks( + self, + run_erasure_node_mock, + update_erasure_tasks_with_access_data_mock, + request_task, + erasure_request_task, + db, + privacy_request, + policy, + ): + assert privacy_request.access_tasks.count() == 3 + assert privacy_request.erasure_tasks.count() == 3 + + # The ready tasks here are all the nodes connected to the erasure node + ready = run_erasure_request( + privacy_request, + db, + privacy_request_proceed=False, + ) + + assert len(ready) == 1 + ready_task = ready[0] + assert not ready_task.is_root_task + assert ready_task == erasure_request_task + + assert ready_task.status == ExecutionLogStatus.pending + assert ready_task.action_type == ActionType.erasure + + assert update_erasure_tasks_with_access_data_mock.called + update_erasure_tasks_with_access_data_mock.called_with(db, privacy_request) + assert run_erasure_node_mock.called + run_erasure_node_mock.assert_called_with(erasure_request_task, False) + + +class TestPersistConsentRequestTasks: + def test_persist_new_consent_request_tasks( + self, + db, + privacy_request, + google_analytics_dataset_config_no_secrets, + ): + graph = build_consent_dataset_graph( + [google_analytics_dataset_config_no_secrets] + ) + + traversal_nodes = {} + # Unlike erasure and access graphs, we don't call traversal.traverse, but build a simpler + # graph that just has one node per dataset + for col_address, node in graph.nodes.items(): + traversal_node = TraversalNode(node) + traversal_nodes[col_address] = traversal_node + + ready_tasks = persist_new_consent_request_tasks( + db, privacy_request, traversal_nodes, {"ga_client_id": "test_id"}, graph + ) + + assert len(ready_tasks) == 1 + root_task = ready_tasks[0] + assert root_task.is_root_task + assert root_task.action_type == ActionType.consent + assert root_task.upstream_tasks == [] + assert root_task.downstream_tasks == [ + "google_analytics_instance:google_analytics_instance" + ] + assert 
root_task.all_descendant_tasks == [ + "__TERMINATE__:__TERMINATE__", + "google_analytics_instance:google_analytics_instance", + ] + assert root_task.status == ExecutionLogStatus.complete + assert root_task.access_data == '[{"ga_client_id": "test_id"}]' + assert root_task.get_decoded_access_data() == [{"ga_client_id": "test_id"}] + terminator_task = privacy_request.get_terminate_task_by_action( + ActionType.consent + ) + + assert terminator_task.is_terminator_task + assert terminator_task.action_type == ActionType.consent + assert terminator_task.upstream_tasks == [ + "google_analytics_instance:google_analytics_instance" + ] + assert terminator_task.downstream_tasks == [] + assert terminator_task.all_descendant_tasks == [] + assert terminator_task.status == ExecutionLogStatus.pending + + ga_task = privacy_request.consent_tasks.filter( + RequestTask.collection_address + == "google_analytics_instance:google_analytics_instance", + ).first() + assert not ga_task.is_root_task + assert not ga_task.is_terminator_task + assert ga_task.action_type == ActionType.consent + # Consent nodes have no data dependencies - they just have the root upstream + # and the terminate node downstream + assert ga_task.upstream_tasks == ["__ROOT__:__ROOT__"] + assert ga_task.downstream_tasks == ["__TERMINATE__:__TERMINATE__"] + assert ga_task.all_descendant_tasks == ["__TERMINATE__:__TERMINATE__"] + assert ga_task.status == ExecutionLogStatus.pending + + # The collection is a fake one for Consent, since requests happen at the dataset level + assert ga_task.collection == { + "name": "google_analytics_instance", + "after": [], + "fields": [], + "erase_after": [], + "grouped_inputs": [], + "skip_processing": False, + } + assert ga_task.traversal_details == { + "input_keys": [], + "incoming_edges": [], + "outgoing_edges": [], + "dataset_connection_key": "google_analytics_instance", + } + + @mock.patch( + "fides.api.task.create_request_tasks.queue_request_task", + ) + def test_run_consent_request_no_request_tasks_existing( + self, run_consent_node_mock, db, privacy_request, policy + ): + ready = run_consent_request( + privacy_request, + DatasetGraph(), + {"email": "customer-4@example.com"}, + db, + privacy_request_proceed=False, + ) + + assert len(ready) == 1 + root_task = ready[0] + assert root_task.is_root_task + + assert run_consent_node_mock.called + run_consent_node_mock.assert_called_with(root_task, False) + + @mock.patch( + "fides.api.task.create_request_tasks.queue_request_task", + ) + def test_reprocess_consent_request_with_existing_request_tasks( + self, run_consent_node_mock, consent_request_task, db, privacy_request, policy + ): + assert privacy_request.consent_tasks.count() == 3 + + ready = run_consent_request( + privacy_request, + DatasetGraph(), + {"email": "customer-4@example.com"}, + db, + privacy_request_proceed=False, + ) + + assert len(ready) == 1 + ready_task = ready[0] + assert ready_task == consent_request_task + assert not ready_task.is_root_task + assert ready_task.action_type == ActionType.consent + assert ready_task.status == ExecutionLogStatus.pending + + assert run_consent_node_mock.called + run_consent_node_mock.assert_called_with(consent_request_task, False) + + # No new consent tasks created + assert privacy_request.consent_tasks.count() == 3 + + +class TestGetExistingReadyTasks: + def test_no_request_tasks(self, privacy_request, db): + assert get_existing_ready_tasks(db, privacy_request, ActionType.access) == [] + + def test_task_should_be_same_action_type(self, privacy_request, db): + rt = 
RequestTask.create( + db, + data={ + "privacy_request_id": privacy_request.id, + "collection_address": "dataset:collection", + "collection_name": "collection", + "dataset_name": "dataset", + "action_type": ActionType.erasure, + "status": "pending", + }, + ) + assert get_existing_ready_tasks(db, privacy_request, ActionType.access) == [] + rt.delete(db) + + def test_task_must_be_incomplete(self, privacy_request, db): + rt = RequestTask.create( + db, + data={ + "privacy_request_id": privacy_request.id, + "collection_address": "dataset:collection", + "collection_name": "collection", + "dataset_name": "dataset", + "action_type": ActionType.erasure, + "status": "complete", + }, + ) + assert get_existing_ready_tasks(db, privacy_request, ActionType.access) == [] + rt.delete(db) + + def test_task_needs_to_have_upstream_complete(self, privacy_request, db): + upstream = RequestTask.create( + db, + data={ + "privacy_request_id": privacy_request.id, + "collection_address": "dataset:other_collection", + "collection_name": "other_collection", + "dataset_name": "dataset", + "action_type": ActionType.access, + "status": "pending", + "upstream_tasks": [], + }, + ) + rt = RequestTask.create( + db, + data={ + "privacy_request_id": privacy_request.id, + "collection_address": "dataset:collection", + "collection_name": "collection", + "dataset_name": "dataset", + "action_type": ActionType.access, + "status": "pending", + "upstream_tasks": [upstream.collection_address], + }, + ) + # rt is not ready but upstream is + assert get_existing_ready_tasks(db, privacy_request, ActionType.access) == [ + upstream + ] + rt.delete(db) + + def test_error_status_is_marked_as_pending(self, privacy_request, db): + upstream = RequestTask.create( + db, + data={ + "privacy_request_id": privacy_request.id, + "collection_address": "dataset:other_collection", + "collection_name": "other_collection", + "dataset_name": "dataset", + "action_type": ActionType.access, + "status": "pending", + "upstream_tasks": [], + }, + ) + rt = RequestTask.create( + db, + data={ + "privacy_request_id": privacy_request.id, + "collection_address": "dataset:collection", + "collection_name": "collection", + "dataset_name": "dataset", + "action_type": ActionType.access, + "status": "error", + "upstream_tasks": [upstream.collection_address], + }, + ) + # rt is not ready but upstream is + assert get_existing_ready_tasks(db, privacy_request, ActionType.access) == [ + upstream + ] + db.refresh(rt) + # The current "errored" task is marked as pending, even if its upstream + # tasks aren't ready + assert rt.status == ExecutionLogStatus.pending + upstream.delete(db) + rt.delete(db) + + def test_ready_tasks(self, privacy_request, db): + rt = RequestTask.create( + db, + data={ + "privacy_request_id": privacy_request.id, + "collection_address": "dataset:collection", + "collection_name": "collection", + "dataset_name": "dataset", + "action_type": ActionType.access, + "status": "pending", + }, + ) + assert get_existing_ready_tasks(db, privacy_request, ActionType.access) == [rt] + + +class TestRunAccessRequestWithRequestTasks: + @pytest.mark.timeout(5) + @pytest.mark.integration + @pytest.mark.integration_postgres + @pytest.mark.integration_mongodb + def test_run_access_request( + self, + db, + privacy_request, + policy, + mongo_inserts, + postgres_inserts, + postgres_integration_db, + integration_mongodb_config, + integration_postgres_config, + ): + mongo_dataset, postgres_dataset = combined_mongo_postgresql_graph( + integration_postgres_config, integration_mongodb_config 
+ ) + + graph = DatasetGraph(mongo_dataset, postgres_dataset) + + identity = {"email": mongo_inserts["customer"][0]["email"]} + + run_access_request( + privacy_request, + policy, + graph, + [integration_postgres_config, integration_mongodb_config], + identity, + db, + privacy_request_proceed=True, + ) + wait_for_tasks_to_complete(db, privacy_request, ActionType.access) + + assert privacy_request.access_tasks.count() == 16 + assert privacy_request.erasure_tasks.count() == 0 + + all_access_tasks = privacy_request.access_tasks.all() + + assert {t.collection_address for t in all_access_tasks} == { + "__ROOT__:__ROOT__", + "mongo_test:customer_feedback", + "postgres_example:customer", + "mongo_test:internal_customer_profile", + "mongo_test:address", + "postgres_example:orders", + "mongo_test:orders", + "mongo_test:customer_details", + "mongo_test:rewards", + "postgres_example:payment_card", + "mongo_test:conversations", + "mongo_test:flights", + "postgres_example:address", + "mongo_test:aircraft", + "mongo_test:employee", + "__TERMINATE__:__TERMINATE__", + } + assert all(t.status == ExecutionLogStatus.complete for t in all_access_tasks) + db.refresh(privacy_request) + assert privacy_request.status == PrivacyRequestStatus.complete + + raw_access_results = privacy_request.get_raw_access_results() + + # Two addresses being found tests that our input_keys are working properly + assert [ + address["id"] for address in raw_access_results["postgres_example:address"] + ] == [1000, 1002] + + customer_details = raw_access_results["mongo_test:customer_details"][0] + assert customer_details["customer_id"] == 10000 + assert customer_details["gender"] == "male" + assert customer_details["birthday"] == datetime(1988, 1, 10, 0, 0) + assert customer_details["workplace_info"] == { + "employer": "Green Tea Company", + "position": "Head Grower", + "direct_reports": ["Margo Robbins"], + } + assert customer_details["emergency_contacts"] == [ + { + "name": "Grace Customer", + "relationship": "mother", + "phone": "123-456-7890", + }, + { + "name": "Joseph Customer", + "relationship": "brother", + "phone": "000-000-0000", + }, + ] + assert customer_details["children"] == ["Kent Customer", "Kenny Customer"] + + @pytest.mark.timeout(5) + @pytest.mark.integration + @pytest.mark.integration_postgres + @pytest.mark.integration_mongodb + def test_run_access_request_with_error( + self, + db, + privacy_request, + policy, + mongo_inserts, + postgres_inserts, + integration_mongodb_config, + integration_postgres_config, + ): + mongo_dataset, postgres_dataset = combined_mongo_postgresql_graph( + integration_postgres_config, integration_mongodb_config + ) + + graph = DatasetGraph(mongo_dataset, postgres_dataset) + + identity = {"email": mongo_inserts["customer"][0]["email"]} + + # Temporarily remove the secrets from the mongo connection to prevent execution from occurring + saved_secrets = integration_mongodb_config.secrets + integration_mongodb_config.secrets = {} + integration_mongodb_config.save(db) + + run_access_request( + privacy_request, + policy, + graph, + [integration_postgres_config, integration_mongodb_config], + {"email": mongo_inserts["customer"][0]["email"]}, + db, + privacy_request_proceed=True, + ) + wait_for_tasks_to_complete(db, privacy_request, ActionType.access) + + assert privacy_request.access_tasks.count() == 16 + assert privacy_request.erasure_tasks.count() == 0 + + postgres_customer_task = privacy_request.access_tasks.filter( + RequestTask.collection_address == "postgres_example:address" + ).first() + 
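The two timestamp captures just below are the crux of this test's second half: once the mongo secrets are restored and the request is re-run, tasks that already completed keep their `updated_at`, while errored tasks are reset and re-executed. Here is a rough sketch of that idempotent-reprocessing contract; `FakeTask` and `rerun` are hypothetical stand-ins for illustration, not fides APIs:

```python
from dataclasses import dataclass
from datetime import datetime, timedelta


@dataclass
class FakeTask:
    status: str
    updated_at: datetime


def rerun(task: FakeTask) -> None:
    # Completed tasks are skipped on reprocessing; anything else is reset
    # and re-executed, which touches its updated_at timestamp.
    if task.status != "complete":
        task.status = "complete"
        task.updated_at = task.updated_at + timedelta(seconds=1)


completed = FakeTask("complete", datetime(2024, 4, 4))
errored = FakeTask("error", datetime(2024, 4, 4))
before = (completed.updated_at, errored.updated_at)

rerun(completed)
rerun(errored)

assert completed.updated_at == before[0]  # untouched on re-run
assert errored.updated_at > before[1]     # modified on re-run
```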
customer_task_updated = postgres_customer_task.updated_at + + mongo_flights_task = privacy_request.access_tasks.filter( + RequestTask.collection_address == "mongo_test:flights" + ).first() + mongo_flights_task_updated = mongo_flights_task.updated_at + + # Mongo tasks are marked as error but the postgres tasks are still + # able to complete. + task_statuses = { + request_task.collection_address: request_task.status.value + for request_task in privacy_request.access_tasks + } + assert task_statuses == { + "__ROOT__:__ROOT__": "complete", + "mongo_test:customer_feedback": "error", + "postgres_example:customer": "complete", + "mongo_test:internal_customer_profile": "error", + "mongo_test:orders": "error", + "mongo_test:customer_details": "error", + "mongo_test:address": "error", + "postgres_example:orders": "complete", + "mongo_test:rewards": "error", + "mongo_test:flights": "error", + "mongo_test:conversations": "error", + "postgres_example:payment_card": "complete", + "mongo_test:aircraft": "error", + "mongo_test:employee": "error", + "postgres_example:address": "complete", + "__TERMINATE__:__TERMINATE__": "error", + } + + integration_mongodb_config.secrets = saved_secrets + integration_mongodb_config.save(db) + + run_access_request( + privacy_request, + policy, + graph, + [integration_postgres_config, integration_mongodb_config], + {"email": mongo_inserts["customer"][0]["email"]}, + db, + privacy_request_proceed=True, + ) + wait_for_tasks_to_complete(db, privacy_request, ActionType.access) + + # No new tasks were created - we just updated the statuses of the old ones + assert privacy_request.access_tasks.count() == 16 + assert privacy_request.erasure_tasks.count() == 0 + + assert all( + t.status == ExecutionLogStatus.complete + for t in privacy_request.access_tasks + ) + db.refresh(privacy_request) + assert privacy_request.status == PrivacyRequestStatus.complete + + # These results are not yet filtered by data category + raw_results = privacy_request.get_raw_access_results() + + # Selected postgres results - retrieved first pass + customer_info = raw_results["postgres_example:customer"][0] + assert customer_info["id"] == 10000 + assert customer_info["email"] == "test_one@example.com" + assert customer_info["address_id"] == 1000 + + # Existing task was unchanged on re-run because it was already completed + db.refresh(postgres_customer_task) + assert postgres_customer_task.updated_at == customer_task_updated + + # Selected Mongo results - retrieved second pass + flight_info = raw_results["mongo_test:flights"][0] + assert flight_info["passenger_information"] == { + "passenger_ids": ["D222-22221"], + "full_name": "John Customer", + } + assert flight_info["flight_no"] == "AA230" + assert flight_info["pilots"] == ["3", "4"] + # Existing task was modified + db.refresh(mongo_flights_task) + assert mongo_flights_task.updated_at > mongo_flights_task_updated + + +class TestRunErasureRequestWithRequestTasks: + @pytest.mark.timeout(15) + @pytest.mark.integration + @pytest.mark.integration_postgres + @pytest.mark.integration_mongodb + def test_run_erasure_request( + self, + db, + mongo_inserts, + postgres_inserts, + privacy_request_with_erasure_policy, + erasure_policy, + example_datasets, + postgres_integration_db, + integration_mongodb_config, + integration_postgres_config, + ): + """Large test handling access and erasure with a failed erasure step""" + mongo_dataset, postgres_dataset = combined_mongo_postgresql_graph( + integration_postgres_config, integration_mongodb_config + ) + + field( + 
[mongo_dataset], "mongo_test", "conversations", "thread", "chat_name" + ).data_categories = ["user.name"] + field( + [postgres_dataset], "postgres_example", "customer", "name" + ).data_categories = ["user.name"] + field( + [mongo_dataset], + "mongo_test", + "customer_details", + "workplace_info", + "direct_reports", + ).data_categories = ["user.name"] + field( + [mongo_dataset], + "mongo_test", + "customer_details", + "emergency_contacts", + "name", + ).data_categories = ["user.name"] + field( + [mongo_dataset], + "mongo_test", + "flights", + "passenger_information", + "full_name", + ).data_categories = ["user.name"] + field([mongo_dataset], "mongo_test", "employee", "name").data_categories = [ + "user.name" + ] + + graph = DatasetGraph(mongo_dataset, postgres_dataset) + + identity = {"email": mongo_inserts["customer"][0]["email"]} + + CONFIG.execution.task_retry_count = 0 + CONFIG.execution.task_retry_delay = 0.1 + CONFIG.execution.task_retry_backoff = 0.01 + p = mock.patch( + "fides.api.service.connectors.MongoDBConnector.mask_data", + new=MagicMock(side_effect=Exception("Key Error")), + ) + p.start() + + assert privacy_request_with_erasure_policy.access_tasks.count() == 0 + assert privacy_request_with_erasure_policy.erasure_tasks.count() == 0 + + run_access_request( + privacy_request_with_erasure_policy, + erasure_policy, + graph, + [integration_postgres_config, integration_mongodb_config], + identity, + db, + privacy_request_proceed=False, + ) + wait_for_tasks_to_complete( + db, privacy_request_with_erasure_policy, ActionType.access + ) + assert privacy_request_with_erasure_policy.access_tasks.count() == 16 + # These were created preemptively alongside the access request tasks so they match + assert privacy_request_with_erasure_policy.erasure_tasks.count() == 16 + + # Run erasure portion first time, but it is expected to fail because + # Mongo connector is not working + run_erasure_request( + privacy_request_with_erasure_policy, db, privacy_request_proceed=False + ) + wait_for_tasks_to_complete( + db, privacy_request_with_erasure_policy, ActionType.erasure + ) + + postgres_customer_task = ( + privacy_request_with_erasure_policy.erasure_tasks.filter( + RequestTask.collection_address == "postgres_example:address" + ).first() + ) + customer_task_updated = postgres_customer_task.updated_at + + mongo_flights_task = privacy_request_with_erasure_policy.erasure_tasks.filter( + RequestTask.collection_address == "mongo_test:flights" + ).first() + mongo_flights_task_updated = mongo_flights_task.updated_at + + # Mongo tasks are marked as error but the postgres tasks are still + # able to complete. 
+ db.refresh(privacy_request_with_erasure_policy) + task_statuses = { + request_task.collection_address: request_task.status.value + for request_task in privacy_request_with_erasure_policy.erasure_tasks + } + assert task_statuses == { + "__ROOT__:__ROOT__": "complete", + "mongo_test:internal_customer_profile": "error", + "mongo_test:rewards": "error", + "postgres_example:customer": "complete", + "mongo_test:customer_feedback": "error", + "mongo_test:employee": "error", + "mongo_test:address": "error", + "postgres_example:payment_card": "complete", + "mongo_test:orders": "error", + "mongo_test:customer_details": "error", + "postgres_example:orders": "complete", + "postgres_example:address": "complete", + "mongo_test:flights": "error", + "mongo_test:conversations": "error", + "mongo_test:aircraft": "error", + "__TERMINATE__:__TERMINATE__": "error", + } + + # Stop mocking MongoDBConnector.mask_data + p.stop() + + # Run erasure one more time + run_erasure_request( + privacy_request_with_erasure_policy, db, privacy_request_proceed=False + ) + wait_for_tasks_to_complete( + db, privacy_request_with_erasure_policy, ActionType.erasure + ) + + assert all( + t.status == ExecutionLogStatus.complete + for t in privacy_request_with_erasure_policy.erasure_tasks + ) + + rows_masked = privacy_request_with_erasure_policy.get_raw_masking_counts() + + # Existing completed task was not touched on Run #2 + db.refresh(postgres_customer_task) + assert postgres_customer_task.updated_at == customer_task_updated + + # Existing error task was modified on run #2 + db.refresh(mongo_flights_task) + assert mongo_flights_task.updated_at > mongo_flights_task_updated + + # No new tasks were created + assert privacy_request_with_erasure_policy.erasure_tasks.count() == 16 + + assert rows_masked == { + "mongo_test:rewards": 0, + "mongo_test:customer_feedback": 0, + "postgres_example:customer": 1, + "mongo_test:employee": 2, + "mongo_test:internal_customer_profile": 0, + "postgres_example:payment_card": 0, + "mongo_test:address": 0, + "mongo_test:orders": 0, + "postgres_example:orders": 0, + "postgres_example:address": 0, + "mongo_test:customer_details": 1, + "mongo_test:conversations": 2, + "mongo_test:flights": 1, + "mongo_test:aircraft": 0, + } + + # Remove request tasks and re-run access request + db.query(RequestTask).filter( + RequestTask.privacy_request_id == privacy_request_with_erasure_policy.id + ).delete() + run_access_request( + privacy_request_with_erasure_policy, + erasure_policy, + graph, + [integration_postgres_config, integration_mongodb_config], + identity, + db, + privacy_request_proceed=False, + ) + wait_for_tasks_to_complete( + db, privacy_request_with_erasure_policy, ActionType.access + ) + raw_access_results = ( + privacy_request_with_erasure_policy.get_raw_access_results() + ) + # erasure policy targeted names with null rewrite strategy + assert raw_access_results["postgres_example:customer"][0]["name"] is None + assert ( + raw_access_results["mongo_test:conversations"][0]["thread"][0]["chat_name"] + is None + ) + assert ( + raw_access_results["mongo_test:conversations"][1]["thread"][0]["chat_name"] + is None + ) + assert ( + raw_access_results["mongo_test:conversations"][1]["thread"][1]["chat_name"] + is None + ) + assert raw_access_results["mongo_test:employee"][0]["name"] is None + assert raw_access_results["mongo_test:employee"][1]["name"] is None + assert raw_access_results["mongo_test:customer_details"][0]["workplace_info"][ + "direct_reports" + ] == [None] + assert not 
raw_access_results["mongo_test:customer_details"][0][ + "emergency_contacts" + ][0]["name"] + assert not raw_access_results["mongo_test:customer_details"][0][ + "emergency_contacts" + ][1]["name"] + assert not raw_access_results["mongo_test:flights"][0]["passenger_information"][ + "full_name" + ] diff --git a/tests/ops/task/test_execute_request_tasks.py b/tests/ops/task/test_execute_request_tasks.py new file mode 100644 index 0000000000..c3c5f553da --- /dev/null +++ b/tests/ops/task/test_execute_request_tasks.py @@ -0,0 +1,521 @@ +import pytest +from sqlalchemy.orm import Session + +from fides.api.common_exceptions import ( + PrivacyRequestCanceled, + PrivacyRequestNotFound, + RequestTaskNotFound, + ResumeTaskException, + UpstreamTasksNotReady, +) +from fides.api.graph.config import ( + ROOT_COLLECTION_ADDRESS, + TERMINATOR_ADDRESS, + Collection, + CollectionAddress, + FieldAddress, + FieldPath, +) +from fides.api.graph.execution import ExecutionNode +from fides.api.graph.graph import DatasetGraph, Edge +from fides.api.graph.traversal import Traversal +from fides.api.models.connectionconfig import ConnectionConfig +from fides.api.models.privacy_request import ( + ExecutionLogStatus, + PrivacyRequestStatus, + RequestTask, +) +from fides.api.schemas.policy import ActionType +from fides.api.service.connectors import PostgreSQLConnector +from fides.api.task.create_request_tasks import ( + collect_tasks_fn, + persist_new_access_request_tasks, +) +from fides.api.task.execute_request_tasks import ( + can_run_task_body, + create_graph_task, + run_prerequisite_task_checks, +) +from fides.api.task.graph_runners import use_dsr_3_0_scheduler +from fides.api.task.graph_task import mark_current_and_downstream_nodes_as_failed +from fides.api.task.task_resources import TaskResources +from fides.api.util.cache import FidesopsRedis, get_cache + + +def _collect_task_resources( + session: Session, request_task: RequestTask +) -> TaskResources: + """Build the TaskResources artifact which just collects some Database resources needed for the current task + Currently just used for testing - + """ + return TaskResources( + request_task.privacy_request, + request_task.privacy_request.policy, + session.query(ConnectionConfig).all(), + request_task, + session, + ) + + +@pytest.fixture() +def create_postgres_access_request_tasks(postgres_dataset_graph, db, privacy_request): + identity = {"email": "customer-1@example.com"} + traversal: Traversal = Traversal(postgres_dataset_graph, identity) + traversal_nodes = {} + end_nodes = traversal.traverse(traversal_nodes, collect_tasks_fn) + + _ = persist_new_access_request_tasks( + db, + privacy_request, + traversal, + traversal_nodes, + end_nodes, + postgres_dataset_graph, + ) + + +class TestRunPrerequisiteTaskChecks: + def test_privacy_request_does_not_exist(self, db): + with pytest.raises(PrivacyRequestNotFound): + run_prerequisite_task_checks(db, "12345", "12345") + + def test_request_task_does_not_exist(self, db, privacy_request): + with pytest.raises(RequestTaskNotFound): + run_prerequisite_task_checks(db, privacy_request.id, "12345") + + def test_privacy_request_was_cancelled(self, db, privacy_request): + privacy_request.status = PrivacyRequestStatus.canceled + privacy_request.save(db) + + with pytest.raises(PrivacyRequestCanceled): + run_prerequisite_task_checks(db, privacy_request.id, "12345") + + @pytest.mark.usefixtures("request_task") + def test_upstream_tasks_not_complete(self, db, privacy_request): + terminator_task = privacy_request.access_tasks.filter( + 
RequestTask.collection_address == TERMINATOR_ADDRESS.value + ).first() + + with pytest.raises(UpstreamTasksNotReady): + # Upstream request task is not ready + run_prerequisite_task_checks(db, privacy_request.id, terminator_task.id) + + def test_upstream_tasks_complete(self, db, privacy_request, request_task): + # Root task is completed so downstream request task can run + root_task = privacy_request.access_tasks.filter( + RequestTask.collection_address == ROOT_COLLECTION_ADDRESS.value + ).first() + pr, rt, ur = run_prerequisite_task_checks( + db, privacy_request.id, request_task.id + ) + assert pr == privacy_request + assert rt == request_task + assert ur.all() == [root_task] + + # Request task is skipped so downstream terminator task can run + terminator_task = privacy_request.access_tasks.filter( + RequestTask.collection_address == TERMINATOR_ADDRESS.value + ).first() + request_task.update_status(db, ExecutionLogStatus.skipped) + + pr, rt, ur = run_prerequisite_task_checks( + db, privacy_request.id, terminator_task.id + ) + assert ur.all() == [request_task] + + +class TestCreateGraphTask: + @pytest.mark.usefixtures("create_postgres_access_request_tasks") + def test_create_graph_task(self, db, privacy_request): + """Request Tasks from the database can be re-hydrated into Graph Tasks""" + + request_task = privacy_request.access_tasks.filter( + RequestTask.collection_address == "postgres_example_test_dataset:address" + ).first() + resources = _collect_task_resources(db, request_task) + + graph_task = create_graph_task(db, request_task, resources) + + assert graph_task.request_task == request_task + assert graph_task.key == CollectionAddress.from_string( + "postgres_example_test_dataset:address" + ) + + execution_node = graph_task.execution_node + assert isinstance(execution_node.collection, Collection) + assert execution_node.address == CollectionAddress.from_string( + "postgres_example_test_dataset:address" + ) + assert isinstance(graph_task.connector, PostgreSQLConnector) + + @pytest.mark.usefixtures("create_postgres_access_request_tasks") + def test_error_hydrating_graph_task(self, db, privacy_request): + """If GraphTask cannot be hydrated, error is thrown, current task and downstream tasks + are marked as error and execution log created for current node + + Normally the Graph Task would take care of this, but in this case, we couldn't create + the graph task in the first place + """ + + request_task = privacy_request.access_tasks.filter( + RequestTask.collection_address == "postgres_example_test_dataset:address" + ).first() + # Set required field to None on RequestTask.collection + request_task.collection["name"] = None + request_task.save(db) + + resources = _collect_task_resources(db, request_task) + + with pytest.raises(ResumeTaskException): + create_graph_task(db, request_task, resources) + + db.refresh(request_task) + + downstream_task = privacy_request.access_tasks.filter( + RequestTask.collection_address == request_task.downstream_tasks[0] + ).first() + assert downstream_task.status == ExecutionLogStatus.error + + execution_log = privacy_request.execution_logs.first() + assert execution_log.dataset_name == "postgres_example_test_dataset" + assert execution_log.collection_name == "address" + assert execution_log.action_type == ActionType.access + assert execution_log.status == ExecutionLogStatus.error + + +class TestExecutionNode: + @pytest.fixture() + @pytest.mark.usefixtures("create_postgres_access_request_tasks") + def address_execution_node( + self, privacy_request, 
create_postgres_access_request_tasks + ): + request_task = privacy_request.access_tasks.filter( + RequestTask.collection_address == "postgres_example_test_dataset:address" + ).first() + + execution_node = ExecutionNode(request_task) + return execution_node + + @pytest.fixture() + @pytest.mark.usefixtures("create_postgres_access_request_tasks") + def employee_execution_node( + self, privacy_request, create_postgres_access_request_tasks + ): + request_task = privacy_request.access_tasks.filter( + RequestTask.collection_address == "postgres_example_test_dataset:employee" + ).first() + + execution_node = ExecutionNode(request_task) + return execution_node + + def test_collection_address(self, address_execution_node): + assert isinstance(address_execution_node.collection, Collection) + assert address_execution_node.address == CollectionAddress.from_string( + "postgres_example_test_dataset:address" + ) + + def test_incoming_edges(self, address_execution_node): + """Assert incoming edges are hydrated from the Traversal details saved on the Request Task""" + assert address_execution_node.incoming_edges == { + Edge( + FieldAddress.from_string( + "postgres_example_test_dataset:payment_card:billing_address_id" + ), + FieldAddress.from_string("postgres_example_test_dataset:address:id"), + ), + Edge( + FieldAddress.from_string( + "postgres_example_test_dataset:customer:address_id" + ), + FieldAddress.from_string("postgres_example_test_dataset:address:id"), + ), + Edge( + FieldAddress.from_string( + "postgres_example_test_dataset:orders:shipping_address_id" + ), + FieldAddress.from_string("postgres_example_test_dataset:address:id"), + ), + Edge( + FieldAddress.from_string( + "postgres_example_test_dataset:employee:address_id" + ), + FieldAddress.from_string("postgres_example_test_dataset:address:id"), + ), + } + assert address_execution_node.outgoing_edges == set() + + def test_input_keys(self, address_execution_node): + """Assert input keys are hydrated from the Traversal details saved on the Request Task""" + + assert address_execution_node.input_keys == [ + CollectionAddress.from_string("postgres_example_test_dataset:customer"), + CollectionAddress.from_string("postgres_example_test_dataset:employee"), + CollectionAddress.from_string("postgres_example_test_dataset:orders"), + CollectionAddress.from_string("postgres_example_test_dataset:payment_card"), + ] + + def test_outgoing_edges(self, employee_execution_node): + """Assert outgoing edges are hydrated from the Traversal details saved on the Request Task""" + + assert employee_execution_node.outgoing_edges == { + Edge( + FieldAddress.from_string("postgres_example_test_dataset:employee:id"), + FieldAddress.from_string( + "postgres_example_test_dataset:service_request:employee_id" + ), + ), + Edge( + FieldAddress.from_string( + "postgres_example_test_dataset:employee:address_id" + ), + FieldAddress.from_string("postgres_example_test_dataset:address:id"), + ), + } + + def test_incoming_edges_by_collection(self, address_execution_node): + """Assert incoming_edges_from_collection are built from incoming edges saved on the traversal details""" + assert address_execution_node.incoming_edges_by_collection == { + CollectionAddress.from_string("postgres_example_test_dataset:customer"): [ + Edge( + FieldAddress.from_string( + "postgres_example_test_dataset:customer:address_id" + ), + FieldAddress.from_string( + "postgres_example_test_dataset:address:id" + ), + ) + ], + CollectionAddress.from_string( + "postgres_example_test_dataset:payment_card" + ): [ + Edge( 
+ FieldAddress.from_string( + "postgres_example_test_dataset:payment_card:billing_address_id" + ), + FieldAddress.from_string( + "postgres_example_test_dataset:address:id" + ), + ) + ], + CollectionAddress.from_string("postgres_example_test_dataset:orders"): [ + Edge( + FieldAddress.from_string( + "postgres_example_test_dataset:orders:shipping_address_id" + ), + FieldAddress.from_string( + "postgres_example_test_dataset:address:id" + ), + ) + ], + CollectionAddress.from_string("postgres_example_test_dataset:employee"): [ + Edge( + FieldAddress.from_string( + "postgres_example_test_dataset:employee:address_id" + ), + FieldAddress.from_string( + "postgres_example_test_dataset:address:id" + ), + ) + ], + } + + @pytest.mark.usefixtures("sentry_connection_config_without_secrets") + def test_grouped_fields( + self, db, privacy_request, sentry_dataset_config_without_secrets + ): + """Test that a config with grouped inputs (sentry saas connector) has grouped inputs persisted""" + merged_graph = sentry_dataset_config_without_secrets.get_graph() + graph = DatasetGraph(merged_graph) + + identity = {"email": "customer-1@example.com"} + traversal: Traversal = Traversal(graph, identity) + traversal_nodes = {} + end_nodes = traversal.traverse(traversal_nodes, collect_tasks_fn) + + _ = persist_new_access_request_tasks( + db, + privacy_request, + traversal, + traversal_nodes, + end_nodes, + graph, + ) + + issues_task = privacy_request.access_tasks.filter( + RequestTask.collection_address == "sentry_dataset:issues" + ).first() + execution_node = ExecutionNode(issues_task) + assert execution_node.grouped_fields == { + "project_slug", + "query", + "organization_slug", + } + + def test_query_field_paths(self, address_execution_node, employee_execution_node): + assert address_execution_node.query_field_paths == { + FieldPath( + "id", + ) + } + assert employee_execution_node.query_field_paths == { + FieldPath("email"), + } + + def test_dependent_identity_fields(self, address_execution_node): + assert not address_execution_node.dependent_identity_fields + + # Edit node to add a grouped field that also is an identity field + address_execution_node.grouped_fields = { + address_execution_node.collection.fields[0].name + } + address_execution_node.collection.fields[0].identity = "email" + assert address_execution_node.dependent_identity_fields + + def test_build_incoming_field_path_maps(self, address_execution_node): + """Light test of most common path, the first tuple""" + field_path_maps = address_execution_node.build_incoming_field_path_maps()[0] + + assert field_path_maps[ + CollectionAddress("postgres_example_test_dataset", "employee") + ] == [ + ( + FieldPath( + "address_id", + ), + FieldPath( + "id", + ), + ) + ] + assert field_path_maps[ + CollectionAddress("postgres_example_test_dataset", "customer") + ] == [ + ( + FieldPath( + "address_id", + ), + FieldPath( + "id", + ), + ) + ] + assert field_path_maps[ + CollectionAddress("postgres_example_test_dataset", "orders") + ] == [ + ( + FieldPath( + "shipping_address_id", + ), + FieldPath( + "id", + ), + ) + ] + assert field_path_maps[ + CollectionAddress("postgres_example_test_dataset", "payment_card") + ] == [ + ( + FieldPath( + "billing_address_id", + ), + FieldPath( + "id", + ), + ) + ] + + def test_typed_filtered_values(self, address_execution_node): + assert address_execution_node.typed_filtered_values({"id": [1, 2]}) == { + "id": [1, 2] + } + assert ( + address_execution_node.typed_filtered_values({"non_existent_id": [1, 2]}) + == {} + ) + + +class 
TestCanRunTaskBody: + def test_task_is_pending(self, request_task): + assert request_task.status == ExecutionLogStatus.pending + assert can_run_task_body(request_task) + + def test_task_is_skipped(self, db, request_task): + request_task.update_status(db, ExecutionLogStatus.skipped) + assert not can_run_task_body(request_task) + + def test_task_is_error(self, db, request_task): + request_task.update_status(db, ExecutionLogStatus.error) + # Error states need to be set to pending when reprocessing + assert not can_run_task_body(request_task) + + def test_task_is_complete(self, db, request_task): + request_task.update_status(db, ExecutionLogStatus.complete) + assert not can_run_task_body(request_task) + + @pytest.mark.usefixtures("request_task") + def test_task_is_root(self, privacy_request): + root_task = privacy_request.get_root_task_by_action(ActionType.access) + assert root_task.status == ExecutionLogStatus.complete + assert not can_run_task_body(root_task) + + @pytest.mark.usefixtures("request_task") + def test_task_is_terminator(self, privacy_request): + terminator_task = privacy_request.get_terminate_task_by_action( + ActionType.access + ) + assert terminator_task.status == ExecutionLogStatus.pending + assert not can_run_task_body(terminator_task) + + +class TestMarkCurrentAndDownstreamNodesAsFailed: + def test_mark_tasks_as_failed( + self, db, privacy_request, request_task, erasure_request_task + ): + root_task = privacy_request.get_root_task_by_action(ActionType.access) + terminator_task = privacy_request.get_terminate_task_by_action( + ActionType.access + ) + assert request_task.status == ExecutionLogStatus.pending + assert terminator_task.status == ExecutionLogStatus.pending + + mark_current_and_downstream_nodes_as_failed(request_task, db) + + db.refresh(root_task) + db.refresh(request_task) + db.refresh(terminator_task) + db.refresh(erasure_request_task) + + # Upstream task unaffected + assert root_task.status == ExecutionLogStatus.complete + # Both current task and terminator task marked as error + assert request_task.status == ExecutionLogStatus.error + assert terminator_task.status == ExecutionLogStatus.error + # Task of a different action type unaffected + assert erasure_request_task.status == ExecutionLogStatus.pending + + +class TestGetDSRVersion: + @pytest.mark.usefixtures("use_dsr_2_0") + def test_use_dsr_2_0(self, privacy_request): + assert use_dsr_3_0_scheduler(privacy_request, ActionType.access) is False + + @pytest.mark.usefixtures("use_dsr_3_0") + def test_use_dsr_3_0(self, privacy_request): + assert use_dsr_3_0_scheduler(privacy_request, ActionType.access) is True + + @pytest.mark.usefixtures("use_dsr_3_0") + def test_use_dsr_2_0_override( + self, + privacy_request, + ): + cache: FidesopsRedis = get_cache() + key = f"access_request__test_dataset:test_collection" + cache.set_encoded_object(f"{privacy_request.id}__{key}", 2) + + # Privacy request already started processing on DSR 2.0 so we continue on DSR 2.0 + assert use_dsr_3_0_scheduler(privacy_request, ActionType.access) is False + + @pytest.mark.usefixtures("use_dsr_2_0") + def test_use_dsr_3_0_override(self, privacy_request, request_task): + # Privacy Request already started processing on 3.0, but we allow it to be switched to 2.0 + assert use_dsr_3_0_scheduler(privacy_request, ActionType.access) is False diff --git a/tests/ops/task/test_graph_task.py b/tests/ops/task/test_graph_task.py index adebf83aae..948b3e5e5c 100644 --- a/tests/ops/task/test_graph_task.py +++ b/tests/ops/task/test_graph_task.py @@ -1,3 +1,4 
@@ +import uuid from typing import Any, Dict from unittest import mock from uuid import uuid4 @@ -27,16 +28,19 @@ from fides.api.models.privacy_request import ExecutionLog, ExecutionLogStatus from fides.api.models.sql_models import Dataset as CtlDataset from fides.api.schemas.policy import ActionType +from fides.api.task.deprecated_graph_task import ( + _evaluate_erasure_dependencies, + format_data_use_map_for_caching, + update_erasure_mapping_from_cache, +) from fides.api.task.graph_task import ( EMPTY_REQUEST, + EMPTY_REQUEST_TASK, GraphTask, TaskResources, - _evaluate_erasure_dependencies, - _format_data_use_map_for_caching, build_affected_field_logs, collect_queries, filter_by_enabled_actions, - update_erasure_mapping_from_cache, ) from fides.api.task.task_resources import Connections from fides.api.util.consent_util import ( @@ -79,12 +83,13 @@ def combined_traversal_node_dict(integration_mongodb_config, connection_config): @pytest.fixture(scope="function") def make_graph_task(integration_mongodb_config, connection_config, db): def task(node): + request_task = node.to_mock_request_task() return MockMongoTask( - node, TaskResources( EMPTY_REQUEST, Policy(), [connection_config, integration_mongodb_config], + request_task, db, ), ) @@ -96,9 +101,10 @@ class TestPreProcessInputData: def test_pre_process_input_data_scalar(self, db) -> None: t = sample_traversal() n = t.traversal_node_dict[CollectionAddress("mysql", "Address")] + request_task = n.to_mock_request_task() task = MockSqlTask( - n, TaskResources(EMPTY_REQUEST, Policy(), connection_configs, db) + TaskResources(EMPTY_REQUEST, Policy(), connection_configs, request_task, db) ) customers_data = [ {"contact_address_id": 31, "foo": "X"}, @@ -285,8 +291,9 @@ def test_pre_process_input_data_group_dependent_fields(self, db): n = traversal_with_grouped_inputs.traversal_node_dict[ CollectionAddress("mysql", "User") ] + request_task = n.to_mock_request_task() task = MockSqlTask( - n, TaskResources(EMPTY_REQUEST, Policy(), connection_configs, db) + TaskResources(EMPTY_REQUEST, Policy(), connection_configs, request_task, db) ) project_output = [ @@ -417,7 +424,9 @@ def test_sql_dry_run_queries(db) -> None: traversal = sample_traversal() env = collect_queries( traversal, - TaskResources(EMPTY_REQUEST, Policy(), connection_configs, db), + TaskResources( + EMPTY_REQUEST, Policy(), connection_configs, EMPTY_REQUEST_TASK, db + ), ) assert ( @@ -461,6 +470,7 @@ def test_mongo_dry_run_queries(db) -> None: key="postgres", connection_type=ConnectionType.mongodb ), ], + EMPTY_REQUEST_TASK, db, ), ) @@ -491,44 +501,50 @@ def node_fixture(self): ] dataset = postgres_order_node.node.dataset - field([dataset], "postgres", "Order", "customer_id").data_categories = ["A"] + field([dataset], "postgres", "Order", "customer_id").data_categories = [ + "user.name" + ] field([dataset], "postgres", "Order", "shipping_address_id").data_categories = [ - "B" + "system.operations" + ] + field([dataset], "postgres", "Order", "order_id").data_categories = [ + "system.operations" ] - field([dataset], "postgres", "Order", "order_id").data_categories = ["B"] field([dataset], "postgres", "Order", "billing_address_id").data_categories = [ - "C" + "user.contact" ] return postgres_order_node - def test_build_affected_field_logs(self, node_fixture): - policy = erasure_policy("A", "B") + def test_build_affected_field_logs(self, db, node_fixture): + policy = erasure_policy(db, "user.name", "system.operations") formatted_for_logs = build_affected_field_logs( node_fixture.node, 
policy, action_type=ActionType.erasure ) - # Only fields for data categories A and B which were specified on the Policy, made it to the logs for this node - assert formatted_for_logs == [ + # Only fields for data categories user.name and system.operations which were specified on the Policy, made it to the logs for this node + assert sorted(formatted_for_logs, key=lambda d: d["field_name"]) == [ { "path": "postgres:Order:customer_id", "field_name": "customer_id", - "data_categories": ["A"], + "data_categories": ["user.name"], }, { "path": "postgres:Order:order_id", "field_name": "order_id", - "data_categories": ["B"], + "data_categories": ["system.operations"], }, { "path": "postgres:Order:shipping_address_id", "field_name": "shipping_address_id", - "data_categories": ["B"], + "data_categories": ["system.operations"], }, ] - def test_build_affected_field_logs_no_data_categories_on_policy(self, node_fixture): - no_categories_policy = erasure_policy() + def test_build_affected_field_logs_no_data_categories_on_policy( + self, db, node_fixture + ): + no_categories_policy = erasure_policy(db) formatted_for_logs = build_affected_field_logs( node_fixture.node, no_categories_policy, @@ -537,8 +553,10 @@ def test_build_affected_field_logs_no_data_categories_on_policy(self, node_fixtu # No data categories specified on policy, so no fields affected assert formatted_for_logs == [] - def test_build_affected_field_logs_no_matching_data_categories(self, node_fixture): - d_categories_policy = erasure_policy("D") + def test_build_affected_field_logs_no_matching_data_categories( + self, db, node_fixture + ): + d_categories_policy = erasure_policy(db, "user.demographic") formatted_for_logs = build_affected_field_logs( node_fixture.node, d_categories_policy, @@ -548,9 +566,9 @@ def test_build_affected_field_logs_no_matching_data_categories(self, node_fixtur assert formatted_for_logs == [] def test_build_affected_field_logs_no_data_categories_for_action_type( - self, node_fixture + self, db, node_fixture ): - policy = erasure_policy("A", "B") + policy = erasure_policy(db, "user.name", "system.operations") formatted_for_logs = build_affected_field_logs( node_fixture.node, policy, @@ -559,38 +577,43 @@ def test_build_affected_field_logs_no_data_categories_for_action_type( # We only have data categories specified on an erasure policy, and we're looking for access action type assert formatted_for_logs == [] - def test_multiple_rules_targeting_same_field(self, node_fixture): - policy = erasure_policy("A") + def test_multiple_rules_targeting_same_field(self, db, node_fixture): + policy = erasure_policy(db, "user.name") - policy.rules = [ - Rule( - action_type=ActionType.erasure, - targets=[RuleTarget(data_category="A")], - masking_strategy={ - "strategy": "null_rewrite", - "configuration": {}, - }, - ), - Rule( - action_type=ActionType.erasure, - targets=[RuleTarget(data_category="A")], - masking_strategy={ - "strategy": "null_rewrite", - "configuration": {}, - }, - ), - ] + rule_1 = Rule( + action_type=ActionType.erasure, + targets=[RuleTarget(data_category="user.name")], + masking_strategy={ + "strategy": "null_rewrite", + "configuration": {}, + }, + policy_id=policy.id, + ) + + target_1 = RuleTarget(data_category="user.name", rule_id=rule_1.id) + + rule_2 = Rule( + action_type=ActionType.erasure, + targets=[RuleTarget(data_category="user.name")], + masking_strategy={ + "strategy": "null_rewrite", + "configuration": {}, + }, + policy_id=policy.id, + ) + + target_2 = RuleTarget(data_category="user.name", 
rule_id=rule_2.id) formatted_for_logs = build_affected_field_logs( node_fixture.node, policy, action_type=ActionType.erasure ) - # No duplication of the matching customer_id field, even though multiple rules targeted data category A + # No duplication of the matching customer_id field, even though multiple rules targeted data category user.name assert formatted_for_logs == [ { "path": "postgres:Order:customer_id", "field_name": "customer_id", - "data_categories": ["A"], + "data_categories": ["user.name"], } ] @@ -598,7 +621,7 @@ def test_multiple_rules_targeting_same_field(self, node_fixture): class TestUpdateErasureMappingFromCache: @pytest.fixture(scope="function") def task_resource(self, privacy_request, policy, db, connection_config): - tr = TaskResources(privacy_request, policy, [], db) + tr = TaskResources(privacy_request, policy, [], EMPTY_REQUEST_TASK, db) tr.get_connector = lambda x: Connections.build_connector(connection_config) return tr @@ -609,7 +632,8 @@ def collect_tasks_fn( ) -> None: """Run the traversal, as an action creating a GraphTask for each traversal_node.""" if not tn.is_root_node(): - data[tn.address] = GraphTask(tn, task_resource) + task_resource.privacy_request_task = tn.to_mock_request_task() + data[tn.address] = GraphTask(task_resource) return collect_tasks_fn @@ -726,7 +750,7 @@ def connection_config_no_system(self, db): data={ "name": str(uuid4()), "key": "connection_config_data_use_map_no_system", - "connection_type": ConnectionType.manual, + "connection_type": ConnectionType.timescale, "access": AccessLevel.write, "disabled": False, }, @@ -749,7 +773,7 @@ def connection_config_system(self, db, system): data={ "name": str(uuid4()), "key": "connection_config_data_use_map", - "connection_type": ConnectionType.manual, + "connection_type": ConnectionType.timescale, "access": AccessLevel.write, "disabled": False, "system_id": system.id, @@ -776,7 +800,7 @@ def connection_config_system_multiple_decs(self, db, system_multiple_decs): data={ "name": str(uuid4()), "key": "connection_config_data_use_map_system_multiple_decs", - "connection_type": ConnectionType.manual, + "connection_type": ConnectionType.timescale, "access": AccessLevel.write, "disabled": False, "system_id": system_multiple_decs.id, @@ -801,7 +825,9 @@ def connection_config_system_multiple_decs(self, db, system_multiple_decs): [ "connection_config_no_system" ], # connection config no system, no data uses - {"postgres_example_subscriptions_dataset_no_system:subscriptions": {}}, + { + "postgres_example_subscriptions_dataset_no_system:subscriptions": set() + }, ), ( [ @@ -830,7 +856,7 @@ def connection_config_system_multiple_decs(self, db, system_multiple_decs): "connection_config_system_multiple_decs", ], { - "postgres_example_subscriptions_dataset_no_system:subscriptions": {}, + "postgres_example_subscriptions_dataset_no_system:subscriptions": set(), "postgres_example_subscriptions_dataset_multiple_decs:subscriptions": { "marketing.advertising", "third_party_sharing", @@ -866,7 +892,9 @@ def test_data_use_map( dataset_graph, {"email": {"test_user@example.com"}} ) env: Dict[CollectionAddress, Any] = {} - task_resources = TaskResources(privacy_request, policy, connection_configs, db) + task_resources = TaskResources( + privacy_request, policy, connection_configs, EMPTY_REQUEST_TASK, db + ) # perform the traversal to populate our `env` dict def collect_tasks_fn( @@ -874,7 +902,8 @@ def collect_tasks_fn( ) -> None: """Run the traversal, as an action creating a GraphTask for each traversal_node.""" if not 
tn.is_root_node(): - data[tn.address] = GraphTask(tn, task_resources) + task_resources.privacy_request_task = tn.to_mock_request_task() + data[tn.address] = GraphTask(task_resources) traversal.traverse( env, @@ -882,7 +911,16 @@ def collect_tasks_fn( ) # ensure that the generated data_use_map looks as expected based on `env` dict - assert _format_data_use_map_for_caching(env) == expected_data_use_map + assert ( + format_data_use_map_for_caching( + { + coll_address: gt.execution_node.connection_key + for (coll_address, gt) in env.items() + }, + connection_configs, + ) + == expected_data_use_map + ) class TestGraphTaskAffectedConsentSystems: @@ -897,13 +935,15 @@ def mock_graph_task( privacy_request_with_consent_policy, privacy_request_with_consent_policy.policy, [mailchimp_transactional_connection_config_no_secrets], + EMPTY_REQUEST_TASK, db, ) tn = TraversalNode(generate_node("a", "b", "c", "c2")) tn.node.dataset.connection_key = ( mailchimp_transactional_connection_config_no_secrets.key ) - return GraphTask(tn, task_resources) + task_resources.privacy_request_task = tn.to_mock_request_task() + return GraphTask(task_resources) @mock.patch( "fides.api.service.connectors.saas_connector.SaaSConnector.run_consent_request" @@ -954,12 +994,14 @@ def test_skipped_consent_task_for_connector( ) assert logs.first().status == ExecutionLogStatus.skipped + @mock.patch("fides.api.task.graph_task.mark_current_and_downstream_nodes_as_failed") @mock.patch( "fides.api.service.connectors.saas_connector.SaaSConnector.run_consent_request" ) def test_errored_consent_task_for_connector_no_relevant_preferences( self, mock_run_consent_request, + mark_current_and_downstream_nodes_as_failed_mock, mailchimp_transactional_connection_config_no_secrets, mock_graph_task, db, @@ -1008,6 +1050,7 @@ def test_errored_consent_task_for_connector_no_relevant_preferences( .order_by(ExecutionLog.created_at.desc()) ) assert logs.first().status == ExecutionLogStatus.error + assert mark_current_and_downstream_nodes_as_failed_mock.called class TestFilterByEnabledActions: @@ -1098,3 +1141,137 @@ def test_filter_by_enabled_actions_mixed_actions(self): assert filter_by_enabled_actions(access_results, connection_configs) == { "dataset1:collection": "data", } + + +class TestGraphTaskLogging: + @pytest.fixture(scope="function") + def graph_task(self, privacy_request, policy, db): + resources = TaskResources( + privacy_request, + policy, + [ + ConnectionConfig( + key="mock_connection_config_key_a", + connection_type=ConnectionType.postgres, + ) + ], + EMPTY_REQUEST_TASK, + db, + ) + tn = TraversalNode(generate_node("a", "b", "c")) + rq = tn.to_mock_request_task() + rq.action_type = ActionType.access + rq.status = ExecutionLogStatus.pending + rq.id = str(uuid.uuid4()) + db.add(rq) + db.commit() + + resources.privacy_request_task = rq + return GraphTask(resources) + + def test_log_start(self, graph_task, db, privacy_request): + graph_task.log_start(action_type=ActionType.access) + + assert graph_task.request_task.status == ExecutionLogStatus.in_processing + + execution_log = ( + db.query(ExecutionLog) + .filter( + ExecutionLog.privacy_request_id == privacy_request.id, + ExecutionLog.collection_name == "b", + ExecutionLog.dataset_name == "a", + ExecutionLog.action_type == ActionType.access, + ) + .first() + ) + assert execution_log.status == ExecutionLogStatus.in_processing + + def test_log_retry(self, graph_task, db, privacy_request): + graph_task.log_retry(action_type=ActionType.access) + + assert graph_task.request_task.status == 
ExecutionLogStatus.retrying + + execution_log = ( + db.query(ExecutionLog) + .filter( + ExecutionLog.privacy_request_id == privacy_request.id, + ExecutionLog.collection_name == "b", + ExecutionLog.dataset_name == "a", + ExecutionLog.action_type == ActionType.access, + ) + .first() + ) + assert execution_log.status == ExecutionLogStatus.retrying + + def test_log_skipped(self, graph_task, db, privacy_request): + graph_task.log_skipped(action_type=ActionType.access, ex="Skipping node") + + assert graph_task.request_task.status == ExecutionLogStatus.skipped + assert graph_task.request_task.consent_sent is None, "Not applicable for access" + + execution_log = ( + db.query(ExecutionLog) + .filter( + ExecutionLog.privacy_request_id == privacy_request.id, + ExecutionLog.collection_name == "b", + ExecutionLog.dataset_name == "a", + ExecutionLog.action_type == ActionType.access, + ) + .first() + ) + assert execution_log.status == ExecutionLogStatus.skipped + + graph_task.log_skipped(action_type=ActionType.consent, ex="Skipping node") + assert graph_task.request_task.consent_sent is False + + @mock.patch("fides.api.task.graph_task.mark_current_and_downstream_nodes_as_failed") + def test_log_end_error( + self, + mark_current_and_downstream_nodes_as_failed_mock, + graph_task, + db, + privacy_request, + ): + graph_task.log_end(action_type=ActionType.access, ex=Exception("Key Error")) + + assert graph_task.request_task.status == ExecutionLogStatus.error + + assert mark_current_and_downstream_nodes_as_failed_mock.called + execution_log = ( + db.query(ExecutionLog) + .filter( + ExecutionLog.privacy_request_id == privacy_request.id, + ExecutionLog.collection_name == "b", + ExecutionLog.dataset_name == "a", + ExecutionLog.action_type == ActionType.access, + ) + .first() + ) + + assert execution_log.status == ExecutionLogStatus.error + + @mock.patch("fides.api.task.graph_task.mark_current_and_downstream_nodes_as_failed") + def test_log_end_complete( + self, + mark_current_and_downstream_nodes_as_failed_mock, + graph_task, + db, + privacy_request, + ): + graph_task.log_end(action_type=ActionType.access) + + assert graph_task.request_task.status == ExecutionLogStatus.complete + + assert not mark_current_and_downstream_nodes_as_failed_mock.called + execution_log = ( + db.query(ExecutionLog) + .filter( + ExecutionLog.privacy_request_id == privacy_request.id, + ExecutionLog.collection_name == "b", + ExecutionLog.dataset_name == "a", + ExecutionLog.action_type == ActionType.access, + ) + .first() + ) + + assert execution_log.status == ExecutionLogStatus.complete diff --git a/tests/ops/task/test_task_resources.py b/tests/ops/task/test_task_resources.py index d5a13794c4..7303274a2c 100644 --- a/tests/ops/task/test_task_resources.py +++ b/tests/ops/task/test_task_resources.py @@ -1,10 +1,13 @@ +from fides.api.task.graph_task import EMPTY_REQUEST_TASK from fides.api.task.task_resources import TaskResources class TestTaskResources: def test_cache_object(self, db, privacy_request, policy, integration_manual_config): + # DSR 3.0 introduced RequestTasks, but you can pass in an empty Request Task for DSR 2.0 + # or for testing into TaskResources resources = TaskResources( - privacy_request, policy, [integration_manual_config], db + privacy_request, policy, [integration_manual_config], EMPTY_REQUEST_TASK, db ) assert resources.get_all_cached_objects() == {} @@ -37,7 +40,7 @@ def test_cache_erasure( self, db, privacy_request, policy, integration_manual_config ): resources = TaskResources( - privacy_request, policy, 
[integration_manual_config], db + privacy_request, policy, [integration_manual_config], EMPTY_REQUEST_TASK, db ) assert resources.get_all_cached_erasures() == {} diff --git a/tests/ops/tasks/test_scheduled.py b/tests/ops/tasks/test_scheduled.py index 9f85ab8f2c..dd25f06c93 100644 --- a/tests/ops/tasks/test_scheduled.py +++ b/tests/ops/tasks/test_scheduled.py @@ -1,6 +1,8 @@ -import pytest +import datetime + from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.date import DateTrigger +from apscheduler.triggers.interval import IntervalTrigger from fides.api.models.privacy_request import PrivacyRequestStatus from fides.api.service.privacy_request.email_batch_service import ( @@ -10,6 +12,12 @@ from fides.api.service.privacy_request.request_runner_service import ( initiate_paused_privacy_request_followup, ) +from fides.api.service.privacy_request.request_service import ( + DSR_DATA_REMOVAL, + PRIVACY_REQUEST_STATUS_CHANGE_POLL, + initiate_poll_for_exited_privacy_request_tasks, + initiate_scheduled_dsr_data_removal, +) from fides.api.tasks.scheduled.scheduler import scheduler from fides.config import get_config @@ -46,3 +54,42 @@ def test_initiate_batch_email_send() -> None: assert type(job.trigger.timezone).__name__ == "US/Eastern" CONFIG.test_mode = True + + +def test_initiate_scheduled_dsr_data_removal() -> None: + """Currently runs weekly to pick up any Request Tasks that expired in the last week + or Privacy Requests that need PII removed from them.""" + CONFIG.test_mode = False + + initiate_scheduled_dsr_data_removal() + assert scheduler.running + job = scheduler.get_job(job_id=DSR_DATA_REMOVAL) + assert job is not None + assert isinstance(job.trigger, CronTrigger) + + assert job.trigger.fields[2].name == "day" + assert str(job.trigger.fields[2].expressions[0]) == "*" + + assert job.trigger.fields[5].name == "hour" + assert str(job.trigger.fields[5].expressions[0]) == "2" + + assert type(job.trigger.timezone).__name__ == "US/Eastern" + + CONFIG.test_mode = True + + +def test_initiate_poll_for_exited_privacy_request_tasks() -> None: + """This task runs on an interval looking for Privacy Requests that need to change state + because all their Request Tasks have had a chance to run but some are errored""" + CONFIG.test_mode = False + + initiate_poll_for_exited_privacy_request_tasks() + assert scheduler.running + job = scheduler.get_job(job_id=PRIVACY_REQUEST_STATUS_CHANGE_POLL) + assert job is not None + assert isinstance(job.trigger, IntervalTrigger) + assert job.trigger.interval == datetime.timedelta( + seconds=CONFIG.execution.state_polling_interval + ) + + CONFIG.test_mode = True diff --git a/tests/ops/test_helpers/cache_secrets_helper.py b/tests/ops/test_helpers/cache_secrets_helper.py index ea7f8e4193..c0694d9865 100644 --- a/tests/ops/test_helpers/cache_secrets_helper.py +++ b/tests/ops/test_helpers/cache_secrets_helper.py @@ -17,3 +17,12 @@ def cache_secret(masking_secret_cache: MaskingSecretCache, request_id: str) -> N def clear_cache_secrets(request_id: str) -> None: cache: FidesopsRedis = get_cache() cache.delete_keys_by_prefix(f"id-{request_id}-masking-secret)") + + +def clear_cache_identities(request_id: str) -> None: + """Testing helper just removes some cached identities from the Privacy Request for testing. 
+ + Some of our Privacy Request fixtures automatically cache identities - + """ + cache: FidesopsRedis = get_cache() + cache.delete_keys_by_prefix(f"id-{request_id}-identity-") diff --git a/tests/ops/util/test_cache.py b/tests/ops/util/test_cache.py index 9f7384ef96..c23510ffc5 100644 --- a/tests/ops/util/test_cache.py +++ b/tests/ops/util/test_cache.py @@ -4,6 +4,7 @@ from datetime import datetime from enum import Enum from typing import Any, List +from unittest import mock import pytest from bson.objectid import ObjectId @@ -13,6 +14,8 @@ ENCODED_DATE_PREFIX, ENCODED_MONGO_OBJECT_ID_PREFIX, FidesopsRedis, + cache_task_tracking_key, + celery_tasks_in_flight, ) from fides.config import CONFIG from tests.fixtures.application_fixtures import faker @@ -190,3 +193,49 @@ def test_decode_pickle_doesnt_break(self): value = b64encode(pickle.dumps(PickleObj())) cache = FidesopsRedis() assert cache.decode_obj(value) is None + + +class TestCacheTaskTrackingKey: + def test_cache_tracking_key_privacy_request(self, privacy_request): + assert privacy_request.get_cached_task_id() is None + + cache_task_tracking_key(privacy_request.id, "test_1234") + + assert privacy_request.get_cached_task_id() == "test_1234" + + def test_cache_tracking_key_request_task(self, request_task): + """Request Task celery tasks are cached in the same location as Privacy Request""" + assert request_task.get_cached_task_id() is None + + cache_task_tracking_key(request_task.id, "test_5678") + + assert request_task.get_cached_task_id() == "test_5678" + + +class TestCeleryTasksInFlight: + def test_celery_tasks_in_flight_no_celery_tasks(self): + assert not celery_tasks_in_flight([]) + + @mock.patch("fides.api.util.cache.celery_app.control.inspect.query_task") + def test_celery_tasks_in_flight_no_workers(self, query_task_mock): + query_task_mock.return_value = {} + + assert not celery_tasks_in_flight(["1234"]) + + @mock.patch("fides.api.util.cache.celery_app.control.inspect.query_task") + def test_celery_tasks_in_flight_no_match_in_queue(self, query_task_mock): + query_task_mock.return_value = {"@celery1234": {}} + + assert not celery_tasks_in_flight(["abcde"]) + + @mock.patch("fides.api.util.cache.celery_app.control.inspect.query_task") + def test_celery_tasks_in_flight_completed_state(self, query_task_mock): + query_task_mock.return_value = {"@celery1234": {"abcde": ["completed", {}]}} + + assert not celery_tasks_in_flight(["abde"]) + + @mock.patch("fides.api.util.cache.celery_app.control.inspect.query_task") + def test_celery_tasks_in_flight_reserved_state(self, query_task_mock): + query_task_mock.return_value = {"@celery1234": {"abcde": ["reserved", {}]}} + + assert celery_tasks_in_flight(["abde"])
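For context on the `TestCeleryTasksInFlight` cases above: the in-flight check pivots on the state each worker reports for a task, where a state like `reserved` means a worker still owns the task and `completed` means it has exited. Below is a minimal sketch of that logic, assuming the `{worker: {task_id: [state, info]}}` shape returned by Celery's `inspect().query_task(...)`; the real `celery_tasks_in_flight` may consider a different set of states:

```python
from typing import Dict, List

# States meaning a worker still owns the task. Illustrative subset only.
IN_FLIGHT_STATES = {"active", "reserved", "scheduled", "received", "started"}


def tasks_in_flight(
    worker_report: Dict[str, Dict[str, List]], task_ids: List[str]
) -> bool:
    """Return True if any queried task is still held by a worker.

    `worker_report` mirrors celery_app.control.inspect().query_task(...):
    {worker_name: {task_id: [state, request_info]}}.
    """
    if not task_ids or not worker_report:
        return False
    for tasks in worker_report.values():
        for state, _info in tasks.values():
            if state in IN_FLIGHT_STATES:
                return True
    return False


assert not tasks_in_flight({}, ["1234"])  # no workers responded
assert not tasks_in_flight({"@w1": {"abcde": ["completed", {}]}}, ["abcde"])
assert tasks_in_flight({"@w1": {"abcde": ["reserved", {}]}}, ["abcde"])
```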