diff --git a/data/dataset/mongo_example_test_dataset.yml b/data/dataset/mongo_example_test_dataset.yml index ec1b0c567..064f61651 100644 --- a/data/dataset/mongo_example_test_dataset.yml +++ b/data/dataset/mongo_example_test_dataset.yml @@ -7,6 +7,8 @@ dataset: fields: - name: _id data_categories: [system.operations] + fidesops_meta: + primary_key: True - name: customer_id data_categories: [user.derived.identifiable.unique_id] fidesops_meta: diff --git a/src/fidesops/graph/config.py b/src/fidesops/graph/config.py index b452a52fa..ab8ec759d 100644 --- a/src/fidesops/graph/config.py +++ b/src/fidesops/graph/config.py @@ -77,6 +77,8 @@ """ from __future__ import annotations +from collections import defaultdict + from typing import List, Optional, Tuple, Set, Dict, Literal from pydantic import BaseModel @@ -223,6 +225,26 @@ def identities(self) -> Dict[str, Tuple[str, ...]]: flds_w_ident = filter(lambda f: f.identity, self.fields) return {f.name: f.identity for f in flds_w_ident} + @property + def fields_by_category(self) -> Dict[str, List]: + """Returns mapping of data categories to fields, flips fields -> categories + to be categories -> fields. + + Example: + { + "user.provided.identifiable.contact.city": ["city"], + "user.provided.identifiable.contact.street": ["house", "street"], + "system.operations": ["id"], + "user.provided.identifiable.contact.state": ["state"], + "user.provided.identifiable.contact.postal_code": ["zip"] + } + """ + categories = defaultdict(list) + for field in self.fields: + for category in field.data_categories or []: + categories[category].append(field.name) + return categories + class Config: """for pydantic incorporation of custom non-pydantic types""" diff --git a/src/fidesops/graph/graph.py b/src/fidesops/graph/graph.py index be27c6abd..d5bb730fc 100644 --- a/src/fidesops/graph/graph.py +++ b/src/fidesops/graph/graph.py @@ -250,9 +250,7 @@ def data_category_field_mapping(self) -> Dict[str, Dict[str, List]]: """ mapping: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list)) for node_address, node in self.nodes.items(): - for field in node.collection.fields: - for category in field.data_categories: - mapping[str(node_address)][category].append(field.name) + mapping[str(node_address)] = node.collection.fields_by_category return mapping def __repr__(self) -> str: diff --git a/src/fidesops/models/policy.py b/src/fidesops/models/policy.py index 41ba5c3a0..60f366a90 100644 --- a/src/fidesops/models/policy.py +++ b/src/fidesops/models/policy.py @@ -38,6 +38,7 @@ from fidesops.service.masking.strategy.masking_strategy_factory import ( SupportedMaskingStrategies, ) +from fidesops.service.masking.strategy.masking_strategy_nullify import NULL_REWRITE class ActionType(EnumType): @@ -100,6 +101,16 @@ def _validate_rule( "Erasure Rules must have masking strategies." ) + # Temporary, remove when we have the pieces in place to support more than null masking. + if ( + action_type == ActionType.erasure.value + and masking_strategy + and masking_strategy.get("strategy") != NULL_REWRITE + ): + raise common_exceptions.RuleValidationError( + "Only the Null Masking Strategy (null_rewrite) is supported at this time." + ) + if action_type == ActionType.access.value and storage_destination_id is None: raise common_exceptions.RuleValidationError( "Access Rules must have a storage destination." 
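# --- Editor's note (not part of the patch): a minimal, self-contained sketch of the
# category inversion that the new Collection.fields_by_category property above performs.
# The field/category pairs below are hypothetical example data, not taken from a real dataset.
from collections import defaultdict
from typing import Dict, List, Tuple

example_fields: List[Tuple[str, List[str]]] = [
    ("id", ["system.operations"]),
    ("house", ["user.provided.identifiable.contact.street"]),
    ("street", ["user.provided.identifiable.contact.street"]),
    ("city", ["user.provided.identifiable.contact.city"]),
]

categories: Dict[str, List[str]] = defaultdict(list)
for field_name, data_categories in example_fields:
    for category in data_categories or []:
        categories[category].append(field_name)

# categories == {
#     "system.operations": ["id"],
#     "user.provided.identifiable.contact.street": ["house", "street"],
#     "user.provided.identifiable.contact.city": ["city"],
# }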
diff --git a/src/fidesops/service/connectors/query_config.py b/src/fidesops/service/connectors/query_config.py
index 3c9056dfb..b70bd76c8 100644
--- a/src/fidesops/service/connectors/query_config.py
+++ b/src/fidesops/service/connectors/query_config.py
@@ -8,7 +8,8 @@
 
 from fidesops.graph.config import ROOT_COLLECTION_ADDRESS, CollectionAddress
 from fidesops.graph.traversal import TraversalNode, Row
-from fidesops.models.policy import Policy
+from fidesops.models.policy import Policy, ActionType, Rule
+from fidesops.service.masking.strategy.masking_strategy_factory import get_strategy
 from fidesops.util.collection_util import append
 
 logging.basicConfig(level=logging.INFO)
@@ -36,29 +37,29 @@ def fields(self) -> List[str]:
         """Fields of interest from this traversal traversal_node."""
         return [f.name for f in self.node.node.collection.fields]
 
-    def update_fields(self, policy: Policy) -> List[str]:
-        """List of update-able field names"""
-
-        def exists_child(
-            field_categories: List[str], policy_categories: List[str]
-        ) -> bool:
-            """A not very efficient check for any policy category that matches one of the field categories or a prefix of it."""
-            if field_categories is None or len(field_categories) == 0:
-                return False
-            for policy_category in policy_categories:
-                for field_category in field_categories:
-                    if field_category.startswith(policy_category):
-                        return True
-
-            return False
-
-        policy_categories = policy.get_erasure_target_categories()
-
-        return [
-            f.name
-            for f in self.node.node.collection.fields
-            if exists_child(f.data_categories, policy_categories)
-        ]
+    def build_rule_target_fields(self, policy: Policy) -> Dict[Rule, List[str]]:
+        """
+        Return dictionary of rules mapped to update-able field names on a given collection
+        Example:
+        {<fidesops.models.policy.Rule object>: ['name', 'code', 'ccn']}
+        """
+        rule_updates: Dict[Rule, List[str]] = {}
+        for rule in policy.rules:
+            if rule.action_type != ActionType.erasure:
+                continue
+            rule_categories = rule.get_target_data_categories()
+            if not rule_categories:
+                continue
+
+            targeted_fields = []
+            collection_categories = self.node.node.collection.fields_by_category
+            for rule_cat in rule_categories:
+                for collection_cat, fields in collection_categories.items():
+                    if collection_cat.startswith(rule_cat):
+                        targeted_fields.extend(fields)
+            rule_updates[rule] = targeted_fields
+
+        return rule_updates
 
     @property
     def primary_keys(self) -> List[str]:
@@ -116,6 +117,28 @@ def display_query_data(self) -> Dict[str, Any]:
 
         return data
 
+    def update_value_map(self, row: Row, policy: Policy) -> Dict[str, Any]:
+        """Map the relevant fields to be updated on the row with their masked values from Policy Rules
+
+        Example return: {'name': None, 'ccn': None, 'code': None}
+
+        In this example, a Null Masking Strategy was used to determine that the name/ccn/code fields
+        for a given customer_id will be replaced with null values.
+ + """ + rule_to_collection_fields = self.build_rule_target_fields(policy) + + value_map: Dict[str, Any] = {} + for rule, field_names in rule_to_collection_fields.items(): + strategy_config = rule.masking_strategy + strategy = get_strategy( + strategy_config["strategy"], strategy_config["configuration"] + ) + + for field_name in field_names: + value_map[field_name] = strategy.mask(row[field_name]) + return value_map + @abstractmethod def generate_query( self, input_data: Dict[str, List[Any]], policy: Optional[Policy] @@ -172,16 +195,14 @@ def generate_query( ) return None - def generate_update_stmt( - self, row: Row, policy: Optional[Policy] = None - ) -> Optional[TextClause]: - """Generate a SQL update statement in the form of a TextClause""" - update_fields = self.update_fields(policy) - update_value_map = {k: None for k in update_fields} - update_clauses = [f"{k} = :{k}" for k in update_fields] + def generate_update_stmt(self, row: Row, policy: Policy) -> Optional[TextClause]: + update_value_map = self.update_value_map(row, policy) + update_clauses = [f"{k} = :{k}" for k in update_value_map] pk_clauses = [f"{k} = :{k}" for k in self.primary_keys] + for pk in self.primary_keys: update_value_map[pk] = row[pk] + valid = len(pk_clauses) > 0 and len(update_clauses) > 0 if not valid: logger.warning( @@ -276,8 +297,7 @@ def generate_update_stmt( self, row: Row, policy: Optional[Policy] = None ) -> Optional[MongoStatement]: """Generate a SQL update statement in the form of Mongo update statement components""" - update_fields = self.update_fields(policy) - update_clauses = {k: None for k in update_fields} + update_clauses = self.update_value_map(row, policy) pk_clauses = {k: row[k] for k in self.primary_keys} valid = len(pk_clauses) > 0 and len(update_clauses) > 0 diff --git a/tests/api/v1/endpoints/test_policy_endpoints.py b/tests/api/v1/endpoints/test_policy_endpoints.py index b0ff38482..69498bd4a 100644 --- a/tests/api/v1/endpoints/test_policy_endpoints.py +++ b/tests/api/v1/endpoints/test_policy_endpoints.py @@ -23,6 +23,7 @@ generate_fides_data_categories, ) from fidesops.service.masking.strategy.masking_strategy_hash import HASH +from fidesops.service.masking.strategy.masking_strategy_nullify import NULL_REWRITE class TestGetPolicies: @@ -451,18 +452,14 @@ def test_create_erasure_rule_for_policy( generate_auth_header, policy, ): - FORMAT_PRESERVATION_SUFFIX = "@masked.com" - HASH_ALGORITHM = "SHA-512" + data = [ { "name": "test erasure rule", "action_type": ActionType.erasure.value, "masking_strategy": { - "strategy": HASH, - "configuration": { - "algorithm": HASH_ALGORITHM, - "format_preservation": {"suffix": FORMAT_PRESERVATION_SUFFIX}, - }, + "strategy": NULL_REWRITE, + "configuration": {}, }, } ] @@ -479,7 +476,7 @@ def test_create_erasure_rule_for_policy( rule_data = response_data[0] assert "masking_strategy" in rule_data masking_strategy_data = rule_data["masking_strategy"] - assert masking_strategy_data["strategy"] == HASH + assert masking_strategy_data["strategy"] == NULL_REWRITE assert "configuration" not in masking_strategy_data def test_update_rule_policy_id_fails( @@ -822,8 +819,8 @@ def test_create_conflicting_rule_targets( "name": "Erasure Rule", "policy_id": policy.id, "masking_strategy": { - "strategy": HASH, - "configuration": {"algorithm": "SHA-512"}, + "strategy": NULL_REWRITE, + "configuration": {}, }, }, ) diff --git a/tests/fixtures.py b/tests/fixtures.py index f263d6719..ca0f27082 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -39,6 +39,8 @@ 
StorageSecrets, StorageType, ) +from fidesops.service.masking.strategy.masking_strategy_nullify import NULL_REWRITE +from fidesops.service.masking.strategy.masking_strategy_string_rewrite import STRING_REWRITE from fidesops.util.cache import FidesopsRedis logging.getLogger("faker").setLevel(logging.ERROR) @@ -222,8 +224,8 @@ def erasure_policy( "name": "Erasure Rule", "policy_id": erasure_policy.id, "masking_strategy": { - "strategy": "hash", - "configuration": {"algorithm": "SHA-512"}, + "strategy": "null_rewrite", + "configuration": {}, }, }, ) @@ -251,6 +253,49 @@ def erasure_policy( pass +@pytest.fixture(scope="function") +def erasure_policy_two_rules(db: Session, oauth_client: ClientDetail, erasure_policy: Policy) -> Generator: + + second_erasure_rule = Rule.create( + db=db, + data={ + "action_type": ActionType.erasure.value, + "client_id": oauth_client.id, + "name": "Second Erasure Rule", + "policy_id": erasure_policy.id, + "masking_strategy": {"strategy": NULL_REWRITE, "configuration": {}}, + }, + ) + + # TODO set masking strategy in Rule.create() call above, once more masking strategies beyond NULL_REWRITE are supported. + second_erasure_rule.masking_strategy = { + "strategy": STRING_REWRITE, + "configuration": {"rewrite_value": "*****"} + } + + second_rule_target = RuleTarget.create( + db=db, + data={ + "client_id": oauth_client.id, + "data_category": DataCategory("user.provided.identifiable.contact.email").value, + "rule_id": second_erasure_rule.id, + }, + ) + yield erasure_policy + try: + second_rule_target.delete(db) + except ObjectDeletedError: + pass + try: + second_erasure_rule.delete(db) + except ObjectDeletedError: + pass + try: + erasure_policy.delete(db) + except ObjectDeletedError: + pass + + @pytest.fixture(scope="function") def policy( db: Session, diff --git a/tests/graph/graph_test_util.py b/tests/graph/graph_test_util.py index 98c1e1627..15cb6d5a6 100644 --- a/tests/graph/graph_test_util.py +++ b/tests/graph/graph_test_util.py @@ -15,6 +15,7 @@ from fidesops.models.privacy_request import PrivacyRequest from fidesops.service.connectors import BaseConnector from fidesops.service.connectors.sql_connector import SQLConnector +from fidesops.service.masking.strategy.masking_strategy_nullify import NullMaskingStrategy from fidesops.task.graph_task import GraphTask from fidesops.task.task_resources import TaskResources from ..fixtures import faker @@ -63,7 +64,10 @@ def erasure_policy(*erasure_categories: str) -> Policy: """Generate an erasure policy with the given categories""" policy = Policy() targets = [RuleTarget(data_category=c) for c in erasure_categories] - policy.rules = [Rule(action_type=ActionType.erasure, targets=targets)] + policy.rules = [Rule(action_type=ActionType.erasure, targets=targets, masking_strategy={ + "strategy": "null_rewrite", + "configuration": {}, + })] return policy diff --git a/tests/models/test_policy.py b/tests/models/test_policy.py index 05e7ee80f..e61db2614 100644 --- a/tests/models/test_policy.py +++ b/tests/models/test_policy.py @@ -16,6 +16,7 @@ _is_ancestor_of_contained_categories, ) from fidesops.service.masking.strategy.masking_strategy_hash import HASH +from fidesops.service.masking.strategy.masking_strategy_nullify import NULL_REWRITE from fidesops.util.text import slugify @@ -211,11 +212,8 @@ def test_create_erasure_rule( "name": "Valid Erasure Rule", "policy_id": policy.id, "masking_strategy": { - "strategy": HASH, - "configuration": { - "algorithm": "SHA-512", - "format_preservation": {"suffix": "@masked.com"}, - }, + 
"strategy": NULL_REWRITE, + "configuration": {}, }, }, ) @@ -313,8 +311,8 @@ def test_validate_policy( "name": "Erasure Rule", "policy_id": erasure_policy.id, "masking_strategy": { - "strategy": HASH, - "configuration": {"algorithm": "SHA-512"}, + "strategy": NULL_REWRITE, + "configuration": {}, }, }, ) @@ -336,8 +334,8 @@ def test_validate_policy( "name": "Another Erasure Rule", "policy_id": erasure_policy.id, "masking_strategy": { - "strategy": HASH, - "configuration": {"algorithm": "SHA-512"}, + "strategy": NULL_REWRITE, + "configuration": {}, }, }, ) @@ -354,4 +352,19 @@ def test_validate_policy( }, ) + with pytest.raises(RuleValidationError): + Rule.create( + db=db, + data={ + "action_type": ActionType.erasure.value, + "client_id": oauth_client.id, + "name": "Erasure Rule with unsupported masking strategy", + "policy_id": erasure_policy.id, + "masking_strategy": { + "strategy": HASH, + "configuration": {"algorithm": "SHA-512"}, + }, + }, + ) + erasure_policy.delete(db=db) # This will tear down everything created here diff --git a/tests/service/connectors/test_queryconfig.py b/tests/service/connectors/test_queryconfig.py index 90a9ac5eb..5a415c23b 100644 --- a/tests/service/connectors/test_queryconfig.py +++ b/tests/service/connectors/test_queryconfig.py @@ -1,11 +1,19 @@ +import pytest from typing import Dict, Any, Set from fidesops.graph.config import CollectionAddress from fidesops.graph.graph import DatasetGraph from fidesops.graph.traversal import Traversal, TraversalNode -from fidesops.models.connectionconfig import ConnectionConfig -from fidesops.service.connectors import PostgreSQLConnector -from fidesops.service.connectors.query_config import QueryConfig, SQLQueryConfig +from fidesops.models.datasetconfig import convert_dataset_to_graph +from fidesops.models.policy import DataCategory +from fidesops.schemas.dataset import FidesopsDataset +from fidesops.schemas.masking.masking_configuration import HashMaskingConfiguration +from fidesops.service.connectors.query_config import ( + QueryConfig, + SQLQueryConfig, + MongoQueryConfig, +) +from fidesops.service.masking.strategy.masking_strategy_hash import HashMaskingStrategy from ...task.traversal_data import integration_db_graph @@ -22,82 +30,319 @@ payment_card_node = traversal_nodes[ CollectionAddress("postgres_example", "payment_card") ] +user_node = traversal_nodes[CollectionAddress("postgres_example", "payment_card")] -def test_extract_query_components(): - def found_query_keys(qconfig: QueryConfig, values: Dict[str, Any]) -> Set[str]: - return set(qconfig.filter_values(values).keys()) +class TestSQLQueryConfig: + def test_extract_query_components(self): + def found_query_keys(qconfig: QueryConfig, values: Dict[str, Any]) -> Set[str]: + return set(qconfig.filter_values(values).keys()) - config = SQLQueryConfig(payment_card_node) - assert config.fields == ["id", "name", "ccn", "customer_id", "billing_address_id"] - assert config.query_keys == {"id", "customer_id"} + config = SQLQueryConfig(payment_card_node) + assert config.fields == [ + "id", + "name", + "ccn", + "customer_id", + "billing_address_id", + ] + assert config.query_keys == {"id", "customer_id"} - # values exist for all query keys - assert found_query_keys( - config, {"id": ["A"], "customer_id": ["V"], "ignore_me": ["X"]} - ) == {"id", "customer_id"} - # with no values OR an empty set, these are omitted - assert found_query_keys( - config, {"id": ["A"], "customer_id": [], "ignore_me": ["X"]} - ) == {"id"} - assert found_query_keys(config, {"id": ["A"], "ignore_me": 
["X"]}) == {"id"} - assert found_query_keys(config, {"ignore_me": ["X"]}) == set() - assert found_query_keys(config, {}) == set() + # values exist for all query keys + assert found_query_keys( + config, {"id": ["A"], "customer_id": ["V"], "ignore_me": ["X"]} + ) == {"id", "customer_id"} + # with no values OR an empty set, these are omitted + assert found_query_keys( + config, {"id": ["A"], "customer_id": [], "ignore_me": ["X"]} + ) == {"id"} + assert found_query_keys(config, {"id": ["A"], "ignore_me": ["X"]}) == {"id"} + assert found_query_keys(config, {"ignore_me": ["X"]}) == set() + assert found_query_keys(config, {}) == set() + def test_filter_values(self): + config = SQLQueryConfig(payment_card_node) + assert config.filter_values( + {"id": ["A"], "customer_id": ["V"], "ignore_me": ["X"]} + ) == {"id": ["A"], "customer_id": ["V"]} -def test_filter_values(): - config = SQLQueryConfig(payment_card_node) - assert config.filter_values( - {"id": ["A"], "customer_id": ["V"], "ignore_me": ["X"]} - ) == {"id": ["A"], "customer_id": ["V"]} + assert config.filter_values( + {"id": ["A"], "customer_id": [], "ignore_me": ["X"]} + ) == {"id": ["A"]} - assert config.filter_values( - {"id": ["A"], "customer_id": [], "ignore_me": ["X"]} - ) == {"id": ["A"]} + assert config.filter_values({"id": ["A"], "ignore_me": ["X"]}) == {"id": ["A"]} - assert config.filter_values({"id": ["A"], "ignore_me": ["X"]}) == {"id": ["A"]} + assert config.filter_values({"id": [], "customer_id": ["V"]}) == { + "customer_id": ["V"] + } - assert config.filter_values({"id": [], "customer_id": ["V"]}) == { - "customer_id": ["V"] - } - - -def test_generated_sql_query(): - """Test that the generated query depends on the input set""" - postgresql_connector = PostgreSQLConnector(ConnectionConfig()) - - assert ( - str( - SQLQueryConfig(payment_card_node).generate_query( - {"id": ["A"], "customer_id": ["V"], "ignore_me": ["X"]} + def test_generated_sql_query(self): + """Test that the generated query depends on the input set""" + assert ( + str( + SQLQueryConfig(payment_card_node).generate_query( + {"id": ["A"], "customer_id": ["V"], "ignore_me": ["X"]} + ) ) + == "SELECT id,name,ccn,customer_id,billing_address_id FROM payment_card WHERE id = :id OR customer_id = :customer_id" ) - == "SELECT id,name,ccn,customer_id,billing_address_id FROM payment_card WHERE id = :id OR customer_id = :customer_id" - ) - assert ( - str( - SQLQueryConfig(payment_card_node).generate_query( - {"id": ["A"], "customer_id": [], "ignore_me": ["X"]} + assert ( + str( + SQLQueryConfig(payment_card_node).generate_query( + {"id": ["A"], "customer_id": [], "ignore_me": ["X"]} + ) ) + == "SELECT id,name,ccn,customer_id,billing_address_id FROM payment_card WHERE id = :id" ) - == "SELECT id,name,ccn,customer_id,billing_address_id FROM payment_card WHERE id = :id" - ) - assert ( - str( - SQLQueryConfig(payment_card_node).generate_query( - {"id": ["A"], "ignore_me": ["X"]} + assert ( + str( + SQLQueryConfig(payment_card_node).generate_query( + {"id": ["A"], "ignore_me": ["X"]} + ) ) + == "SELECT id,name,ccn,customer_id,billing_address_id FROM payment_card WHERE id = :id" ) - == "SELECT id,name,ccn,customer_id,billing_address_id FROM payment_card WHERE id = :id" - ) - assert ( - str( - SQLQueryConfig(payment_card_node).generate_query( - {"id": [], "customer_id": ["V"]} + assert ( + str( + SQLQueryConfig(payment_card_node).generate_query( + {"id": [], "customer_id": ["V"]} + ) ) + == "SELECT id,name,ccn,customer_id,billing_address_id FROM payment_card WHERE customer_id = 
:customer_id" + ) + + def test_update_rule_target_fields( + self, erasure_policy, example_datasets, integration_postgres_config + ): + dataset = FidesopsDataset(**example_datasets[0]) + graph = convert_dataset_to_graph(dataset, integration_postgres_config.key) + dataset_graph = DatasetGraph(*[graph]) + traversal = Traversal(dataset_graph, {"email": "customer-1@example.com"}) + + customer_node = traversal.traversal_node_dict[ + CollectionAddress("postgres_example_test_dataset", "customer") + ] + + rule = erasure_policy.rules[0] + config = SQLQueryConfig(customer_node) + assert config.build_rule_target_fields(erasure_policy) == {rule: ["name"]} + + # Make target more broad + target = rule.targets[0] + target.data_category = DataCategory("user.provided.identifiable").value + assert config.build_rule_target_fields(erasure_policy) == { + rule: ["email", "name"] + } + + # Check different collection + address_node = traversal.traversal_node_dict[ + CollectionAddress("postgres_example_test_dataset", "address") + ] + config = SQLQueryConfig(address_node) + assert config.build_rule_target_fields(erasure_policy) == { + rule: ["city", "house", "street", "state", "zip"] + } + + def test_generate_update_stmt_one_field( + self, erasure_policy, example_datasets, integration_postgres_config + ): + dataset = FidesopsDataset(**example_datasets[0]) + graph = convert_dataset_to_graph(dataset, integration_postgres_config.key) + dataset_graph = DatasetGraph(*[graph]) + traversal = Traversal(dataset_graph, {"email": "customer-1@example.com"}) + + customer_node = traversal.traversal_node_dict[ + CollectionAddress("postgres_example_test_dataset", "customer") + ] + + config = SQLQueryConfig(customer_node) + row = { + "email": "customer-1@example.com", + "name": "John Customer", + "address_id": 1, + "id": 1, + } + + text_clause = config.generate_update_stmt(row, erasure_policy) + assert ( + text_clause.text == """UPDATE customer SET name = :name WHERE id = :id""" + ) + assert text_clause._bindparams["name"].key == "name" + assert text_clause._bindparams["name"].value is None # Null masking strategy + + def test_generate_update_stmt_multiple_fields_same_rule( + self, erasure_policy, example_datasets, integration_postgres_config + ): + dataset = FidesopsDataset(**example_datasets[0]) + graph = convert_dataset_to_graph(dataset, integration_postgres_config.key) + dataset_graph = DatasetGraph(*[graph]) + traversal = Traversal(dataset_graph, {"email": "customer-1@example.com"}) + + customer_node = traversal.traversal_node_dict[ + CollectionAddress("postgres_example_test_dataset", "customer") + ] + + config = SQLQueryConfig(customer_node) + row = { + "email": "customer-1@example.com", + "name": "John Customer", + "address_id": 1, + "id": 1, + } + + # Make target more broad + rule = erasure_policy.rules[0] + target = rule.targets[0] + target.data_category = DataCategory("user.provided.identifiable").value + + # Update rule masking strategy + rule.masking_strategy = { + "strategy": "hash", + "configuration": {"algorithm": "SHA-512"}, + } + + text_clause = config.generate_update_stmt(row, erasure_policy) + assert ( + text_clause.text + == "UPDATE customer SET email = :email,name = :name WHERE id = :id" + ) + assert text_clause._bindparams["name"].key == "name" + assert text_clause._bindparams["name"].value == HashMaskingStrategy( + HashMaskingConfiguration(algorithm="SHA-512") + ).mask("John Customer") + assert text_clause._bindparams["email"].value == HashMaskingStrategy( + HashMaskingConfiguration(algorithm="SHA-512") + 
).mask("customer-1@example.com") + + def test_generate_update_stmts_from_multiple_rules( + self, erasure_policy_two_rules, example_datasets, integration_postgres_config + ): + dataset = FidesopsDataset(**example_datasets[0]) + graph = convert_dataset_to_graph(dataset, integration_postgres_config.key) + dataset_graph = DatasetGraph(*[graph]) + traversal = Traversal(dataset_graph, {"email": "customer-1@example.com"}) + row = { + "email": "customer-1@example.com", + "name": "John Customer", + "address_id": 1, + "id": 1, + } + + customer_node = traversal.traversal_node_dict[ + CollectionAddress("postgres_example_test_dataset", "customer") + ] + + config = SQLQueryConfig(customer_node) + + text_clause = config.generate_update_stmt(row, erasure_policy_two_rules) + + assert ( + text_clause.text + == "UPDATE customer SET name = :name,email = :email WHERE id = :id" + ) + # Two different masking strategies used for name and email + assert text_clause._bindparams["name"].value is None # Null masking strategy + assert ( + text_clause._bindparams["email"].value == "*****" + ) # String rewrite masking strategy + + +class TestMongoQueryConfig: + def test_generate_update_stmt_multiple_fields( + self, + erasure_policy, + example_datasets, + integration_mongodb_config, + integration_postgres_config, + ): + dataset_postgres = FidesopsDataset(**example_datasets[0]) + graph = convert_dataset_to_graph( + dataset_postgres, integration_postgres_config.key + ) + dataset_mongo = FidesopsDataset(**example_datasets[1]) + mongo_graph = convert_dataset_to_graph( + dataset_mongo, integration_mongodb_config.key + ) + dataset_graph = DatasetGraph(*[graph, mongo_graph]) + + traversal = Traversal(dataset_graph, {"email": "customer-1@example.com"}) + + customer_details = traversal.traversal_node_dict[ + CollectionAddress("mongo_test", "customer_details") + ] + + config = MongoQueryConfig(customer_details) + row = { + "birthday": "1988-01-10", + "gender": "male", + "customer_id": 1, + "_id": 1, + } + + # Make target more broad + rule = erasure_policy.rules[0] + target = rule.targets[0] + target.data_category = DataCategory("user.provided.identifiable").value + + mongo_statement = config.generate_update_stmt(row, erasure_policy) + assert mongo_statement[0] == {"_id": 1} + assert mongo_statement[1] == {"$set": {"birthday": None, "gender": None}} + + def test_generate_update_stmt_multiple_rules( + self, + erasure_policy_two_rules, + example_datasets, + integration_mongodb_config, + integration_postgres_config, + ): + dataset_postgres = FidesopsDataset(**example_datasets[0]) + graph = convert_dataset_to_graph( + dataset_postgres, integration_postgres_config.key + ) + dataset_mongo = FidesopsDataset(**example_datasets[1]) + mongo_graph = convert_dataset_to_graph( + dataset_mongo, integration_mongodb_config.key ) - == "SELECT id,name,ccn,customer_id,billing_address_id FROM payment_card WHERE customer_id = :customer_id" - ) + dataset_graph = DatasetGraph(*[graph, mongo_graph]) + + traversal = Traversal(dataset_graph, {"email": "customer-1@example.com"}) + + customer_details = traversal.traversal_node_dict[ + CollectionAddress("mongo_test", "customer_details") + ] + + config = MongoQueryConfig(customer_details) + row = { + "birthday": "1988-01-10", + "gender": "male", + "customer_id": 1, + "_id": 1, + } + + rule = erasure_policy_two_rules.rules[0] + rule.masking_strategy = { + "strategy": "hash", + "configuration": {"algorithm": "SHA-512"}, + } + target = rule.targets[0] + target.data_category = DataCategory( + 
"user.provided.identifiable.date_of_birth" + ).value + + rule_two = erasure_policy_two_rules.rules[1] + rule_two.masking_strategy = { + "strategy": "random_string_rewrite", + "configuration": {"length": 30}, + } + target = rule_two.targets[0] + target.data_category = DataCategory("user.provided.identifiable.gender").value + + mongo_statement = config.generate_update_stmt(row, erasure_policy_two_rules) + assert mongo_statement[0] == {"_id": 1} + assert len(mongo_statement[1]["$set"]["gender"]) == 30 + assert mongo_statement[1]["$set"]["birthday"] == HashMaskingStrategy( + HashMaskingConfiguration(algorithm="SHA-512") + ).mask("1988-01-10")