Skip to content
This repository has been archived by the owner on Nov 30, 2022. It is now read-only.

Adds MariaDB query execution tests #191

Merged
merged 3 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2059,7 +2059,7 @@
"header": [],
"body": {
"mode": "raw",
"raw": "{\n \"host\": \"mariadb_example\",\n \"port\": 3808,\n \"dbname\": \"mariadb_example\",\n \"username\": \"mariadb_user\",\n \"password\": \"mariadb_pw\"\n}",
"raw": "{\n \"host\": \"mariadb_example\",\n \"port\": 3806,\n \"dbname\": \"mariadb_example\",\n \"username\": \"mariadb_user\",\n \"password\": \"mariadb_pw\"\n}",
"options": {
"raw": {
"language": "json"
Expand Down
36 changes: 15 additions & 21 deletions src/fidesops/service/connectors/sql_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ def cursor_result_to_rows(results: CursorResult) -> List[Row]:
)
return rows

@staticmethod
def default_cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]:
"""
Convert SQLAlchemy results to a list of dictionaries
Overrides BaseConnector.cursor_result_to_rows since SQLAlchemy execute returns LegacyCursorResult for MariaDB
"""
columns: List[Column] = results.cursor.description
rows = []
for row_tuple in results:
rows.append({col[0]: row_tuple[count] for count, col in enumerate(columns)})
return rows

@abstractmethod
def build_uri(self) -> str:
"""Build a database specific uri connection string"""
Expand Down Expand Up @@ -187,18 +199,12 @@ def create_client(self) -> Engine:
echo=not self.hide_parameters,
)

# Overrides BaseConnector.cursor_result_to_rows
@staticmethod
def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]:
"""
Convert SQLAlchemy results to a list of dictionaries
Overrides BaseConnector.cursor_result_to_rows since SQLAlchemy execute returns LegacyCursorResult for MySQL
"""
columns: List[Column] = results.cursor.description
rows = []
for row_tuple in results:
rows.append({col[0]: row_tuple[count] for count, col in enumerate(columns)})
return rows
return SQLConnector.default_cursor_result_to_rows(results)


class MariaDBConnector(SQLConnector):
Expand Down Expand Up @@ -230,18 +236,12 @@ def create_client(self) -> Engine:
echo=not self.hide_parameters,
)

# Overrides BaseConnector.cursor_result_to_rows
@staticmethod
def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]:
"""
Convert SQLAlchemy results to a list of dictionaries
Overrides BaseConnector.cursor_result_to_rows since SQLAlchemy execute returns LegacyCursorResult for MariaDB
"""
columns: List[Column] = results.cursor.description
rows = []
for row_tuple in results:
rows.append({col[0]: row_tuple[count] for count, col in enumerate(columns)})
return rows
return SQLConnector.default_cursor_result_to_rows(results)


class RedshiftConnector(SQLConnector):
Expand Down Expand Up @@ -413,15 +413,9 @@ def query_config(self, node: TraversalNode) -> SQLQueryConfig:
"""Query wrapper corresponding to the input traversal_node."""
return MicrosoftSQLServerQueryConfig(node)

# Overrides BaseConnector.cursor_result_to_rows
@staticmethod
def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]:
"""
Convert SQLAlchemy results to a list of dictionaries
Overrides BaseConnector.cursor_result_to_rows since SQLAlchemy execute returns LegacyCursorResult for MsSQL
"""
columns: List[Column] = results.cursor.description
rows = []
for row_tuple in results:
rows.append({col[0]: row_tuple[count] for count, col in enumerate(columns)})
return rows
return SQLConnector.default_cursor_result_to_rows(results)
82 changes: 82 additions & 0 deletions tests/integration_tests/test_sql_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,88 @@ def test_mysql_access_request_task(db, policy, connection_config_mysql) -> None:
)


@pytest.mark.integration
def test_mariadb_access_request_task(db, policy, connection_config_mariadb) -> None:

privacy_request = PrivacyRequest(
id=f"test_mariadb_access_request_task_{random.randint(0, 1000)}"
)

v = graph_task.run_access_request(
privacy_request,
policy,
integration_db_graph("my_maria_db_1"),
[connection_config_mariadb],
{"email": "[email protected]"},
)

assert_rows_match(
v["my_maria_db_1:address"],
min_size=2,
keys=["id", "street", "city", "state", "zip"],
)
assert_rows_match(
v["my_maria_db_1:orders"],
min_size=3,
keys=["id", "customer_id", "shipping_address_id", "payment_card_id"],
)
assert_rows_match(
v["my_maria_db_1:payment_card"],
min_size=2,
keys=["id", "name", "ccn", "customer_id", "billing_address_id"],
)
assert_rows_match(
v["my_maria_db_1:customer"],
min_size=1,
keys=["id", "name", "email", "address_id"],
)

# links
assert v["my_maria_db_1:customer"][0]["email"] == "[email protected]"

logs = (
ExecutionLog.query(db=db)
.filter(ExecutionLog.privacy_request_id == privacy_request.id)
.all()
)

logs = [log.__dict__ for log in logs]
assert (
len(
records_matching_fields(
logs, dataset_name="my_maria_db_1", collection_name="customer"
)
)
> 0
)
assert (
len(
records_matching_fields(
logs, dataset_name="my_maria_db_1", collection_name="address"
)
)
> 0
)
assert (
len(
records_matching_fields(
logs, dataset_name="my_maria_db_1", collection_name="orders"
)
)
> 0
)
assert (
len(
records_matching_fields(
logs,
dataset_name="my_maria_db_1",
collection_name="payment_card",
)
)
> 0
)


@pytest.mark.integration
def test_filter_on_data_categories(
db,
Expand Down
80 changes: 79 additions & 1 deletion tests/service/privacy_request/request_runner_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
SnowflakeConnector,
RedshiftConnector,
MicrosoftSQLServerConnector,
MySQLConnector,
MySQLConnector, MariaDBConnector,
)
from fidesops.service.masking.strategy.masking_strategy_factory import get_strategy
from fidesops.service.privacy_request.request_runner_service import PrivacyRequestRunner
Expand Down Expand Up @@ -254,6 +254,45 @@ def test_create_and_process_access_request_mysql(
pr.delete(db=db)


@pytest.mark.integration
@mock.patch("fidesops.models.privacy_request.PrivacyRequest.trigger_policy_webhook")
def test_create_and_process_access_request_mariadb(
trigger_webhook_mock,
mariadb_example_test_dataset_config,
db,
cache,
policy,
policy_pre_execution_webhooks,
policy_post_execution_webhooks,
):

customer_email = "[email protected]"
data = {
"requested_at": "2021-08-30T16:09:37.359Z",
"policy_key": policy.key,
"identity": {"email": customer_email},
}

pr = get_privacy_request_results(db, policy, cache, data)

results = pr.get_results()
assert len(results.keys()) == 11

for key in results.keys():
assert results[key] is not None
assert results[key] != {}

result_key_prefix = f"EN_{pr.id}__access_request__mariadb_example_test_dataset:"
customer_key = result_key_prefix + "customer"
assert results[customer_key][0]["email"] == customer_email

visit_key = result_key_prefix + "visit"
assert results[visit_key][0]["email"] == customer_email
# Both pre-execution webhooks and both post-execution webhooks were called
assert trigger_webhook_mock.call_count == 4
pr.delete(db=db)


@pytest.mark.integration_erasure
def test_create_and_process_erasure_request_specific_category(
postgres_example_test_dataset_config,
Expand Down Expand Up @@ -371,6 +410,45 @@ def test_create_and_process_erasure_request_specific_category_mysql(
assert customer_found


@pytest.mark.integration_erasure
def test_create_and_process_erasure_request_specific_category_mariadb(
mariadb_example_test_dataset_config,
cache,
db,
generate_auth_header,
erasure_policy,
connection_config_mariadb,
):
customer_email = "[email protected]"
customer_id = 1
data = {
"requested_at": "2021-08-30T16:09:37.359Z",
"policy_key": erasure_policy.key,
"identity": {"email": customer_email},
}

pr = get_privacy_request_results(db, erasure_policy, cache, data)
pr.delete(db=db)

example_mariadb_uri = MariaDBConnector(connection_config_mariadb).build_uri()
engine = get_db_engine(database_uri=example_mariadb_uri)
SessionLocal = get_db_session(engine=engine)
integration_db = SessionLocal()
stmt = select(
column("id"),
column("name"),
).select_from(table("customer"))
res = integration_db.execute(stmt).all()

customer_found = False
for row in res:
if customer_id in row:
customer_found = True
# Check that the `name` field is `None`
assert row.name is None
assert customer_found


@pytest.mark.integration_erasure
def test_create_and_process_erasure_request_generic_category(
postgres_example_test_dataset_config,
Expand Down