Adds MariaDB query execution tests (#191)
eastandwestwind authored Feb 9, 2022
1 parent aab9431 commit 188ebff
Showing 4 changed files with 177 additions and 23 deletions.
@@ -2059,7 +2059,7 @@
     "header": [],
     "body": {
         "mode": "raw",
-        "raw": "{\n \"host\": \"mariadb_example\",\n \"port\": 3808,\n \"dbname\": \"mariadb_example\",\n \"username\": \"mariadb_user\",\n \"password\": \"mariadb_pw\"\n}",
+        "raw": "{\n \"host\": \"mariadb_example\",\n \"port\": 3806,\n \"dbname\": \"mariadb_example\",\n \"username\": \"mariadb_user\",\n \"password\": \"mariadb_pw\"\n}",
         "options": {
             "raw": {
                 "language": "json"
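For orientation, the secrets in this request body are what the connector turns into a SQLAlchemy connection URI. A minimal sketch of that assembly follows; the `mariadb+pymysql` dialect string is an assumption for illustration only, since the actual `MariaDBConnector.build_uri()` implementation is not part of this diff.

# Illustrative only: hypothetical URI assembled from the request body above.
# The real driver/dialect used by MariaDBConnector.build_uri() is not shown here.
host, port, dbname = "mariadb_example", 3806, "mariadb_example"
username, password = "mariadb_user", "mariadb_pw"
uri = f"mariadb+pymysql://{username}:{password}@{host}:{port}/{dbname}"
print(uri)  # mariadb+pymysql://mariadb_user:mariadb_pw@mariadb_example:3806/mariadb_example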
36 changes: 15 additions & 21 deletions src/fidesops/service/connectors/sql_connector.py
@@ -58,6 +58,18 @@ def cursor_result_to_rows(results: CursorResult) -> List[Row]:
             )
         return rows

+    @staticmethod
+    def default_cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]:
+        """
+        Convert SQLAlchemy results to a list of dictionaries.
+        Shared implementation for dialects (MySQL, MariaDB, MSSQL) whose execute() returns a LegacyCursorResult rather than a CursorResult.
+        """
+        columns: List[Column] = results.cursor.description
+        rows = []
+        for row_tuple in results:
+            rows.append({col[0]: row_tuple[count] for count, col in enumerate(columns)})
+        return rows
+
     @abstractmethod
     def build_uri(self) -> str:
         """Build a database specific uri connection string"""
@@ -187,18 +199,12 @@ def create_client(self) -> Engine:
             echo=not self.hide_parameters,
         )

-    # Overrides BaseConnector.cursor_result_to_rows
     @staticmethod
     def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]:
         """
         Convert SQLAlchemy results to a list of dictionaries
-        Overrides BaseConnector.cursor_result_to_rows since SQLAlchemy execute returns LegacyCursorResult for MySQL
         """
-        columns: List[Column] = results.cursor.description
-        rows = []
-        for row_tuple in results:
-            rows.append({col[0]: row_tuple[count] for count, col in enumerate(columns)})
-        return rows
+        return SQLConnector.default_cursor_result_to_rows(results)


 class MariaDBConnector(SQLConnector):
@@ -230,18 +236,12 @@ def create_client(self) -> Engine:
             echo=not self.hide_parameters,
         )

-    # Overrides BaseConnector.cursor_result_to_rows
     @staticmethod
     def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]:
         """
         Convert SQLAlchemy results to a list of dictionaries
-        Overrides BaseConnector.cursor_result_to_rows since SQLAlchemy execute returns LegacyCursorResult for MariaDB
         """
-        columns: List[Column] = results.cursor.description
-        rows = []
-        for row_tuple in results:
-            rows.append({col[0]: row_tuple[count] for count, col in enumerate(columns)})
-        return rows
+        return SQLConnector.default_cursor_result_to_rows(results)


 class RedshiftConnector(SQLConnector):
@@ -413,15 +413,9 @@ def query_config(self, node: TraversalNode) -> SQLQueryConfig:
         """Query wrapper corresponding to the input traversal_node."""
         return MicrosoftSQLServerQueryConfig(node)

-    # Overrides BaseConnector.cursor_result_to_rows
     @staticmethod
     def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]:
         """
         Convert SQLAlchemy results to a list of dictionaries
-        Overrides BaseConnector.cursor_result_to_rows since SQLAlchemy execute returns LegacyCursorResult for MsSQL
         """
-        columns: List[Column] = results.cursor.description
-        rows = []
-        for row_tuple in results:
-            rows.append({col[0]: row_tuple[count] for count, col in enumerate(columns)})
-        return rows
+        return SQLConnector.default_cursor_result_to_rows(results)
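The shared helper above relies only on the DB-API `cursor.description` tuples that SQLAlchemy's `LegacyCursorResult` exposes, zipping each column name with the corresponding position in every row tuple. Below is a self-contained sketch of that same zipping logic using stand-in data instead of a live engine; the names and sample values are illustrative and not part of this commit.

from typing import Any, Dict, List, Tuple

# Stand-ins for results.cursor.description (first element of each tuple is the
# column name) and for the row tuples an executed statement would yield.
description: List[Tuple[Any, ...]] = [("id",), ("name",), ("email",)]
row_tuples: List[Tuple[Any, ...]] = [
    (1, "Jane", "jane@example.com"),
    (2, "John", "john@example.com"),
]

def rows_from_cursor(description, row_tuples) -> List[Dict[str, Any]]:
    """Mirror of default_cursor_result_to_rows: map column names onto row positions."""
    return [
        {col[0]: row[count] for count, col in enumerate(description)}
        for row in row_tuples
    ]

print(rows_from_cursor(description, row_tuples))
# [{'id': 1, 'name': 'Jane', 'email': 'jane@example.com'}, {'id': 2, ...}]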
82 changes: 82 additions & 0 deletions tests/integration_tests/test_sql_task.py
@@ -503,6 +503,88 @@ def test_mysql_access_request_task(db, policy, connection_config_mysql) -> None:
     )


+@pytest.mark.integration
+def test_mariadb_access_request_task(db, policy, connection_config_mariadb) -> None:
+
+    privacy_request = PrivacyRequest(
+        id=f"test_mariadb_access_request_task_{random.randint(0, 1000)}"
+    )
+
+    v = graph_task.run_access_request(
+        privacy_request,
+        policy,
+        integration_db_graph("my_maria_db_1"),
+        [connection_config_mariadb],
+        {"email": "[email protected]"},
+    )
+
+    assert_rows_match(
+        v["my_maria_db_1:address"],
+        min_size=2,
+        keys=["id", "street", "city", "state", "zip"],
+    )
+    assert_rows_match(
+        v["my_maria_db_1:orders"],
+        min_size=3,
+        keys=["id", "customer_id", "shipping_address_id", "payment_card_id"],
+    )
+    assert_rows_match(
+        v["my_maria_db_1:payment_card"],
+        min_size=2,
+        keys=["id", "name", "ccn", "customer_id", "billing_address_id"],
+    )
+    assert_rows_match(
+        v["my_maria_db_1:customer"],
+        min_size=1,
+        keys=["id", "name", "email", "address_id"],
+    )
+
+    # links
+    assert v["my_maria_db_1:customer"][0]["email"] == "[email protected]"
+
+    logs = (
+        ExecutionLog.query(db=db)
+        .filter(ExecutionLog.privacy_request_id == privacy_request.id)
+        .all()
+    )
+
+    logs = [log.__dict__ for log in logs]
+    assert (
+        len(
+            records_matching_fields(
+                logs, dataset_name="my_maria_db_1", collection_name="customer"
+            )
+        )
+        > 0
+    )
+    assert (
+        len(
+            records_matching_fields(
+                logs, dataset_name="my_maria_db_1", collection_name="address"
+            )
+        )
+        > 0
+    )
+    assert (
+        len(
+            records_matching_fields(
+                logs, dataset_name="my_maria_db_1", collection_name="orders"
+            )
+        )
+        > 0
+    )
+    assert (
+        len(
+            records_matching_fields(
+                logs,
+                dataset_name="my_maria_db_1",
+                collection_name="payment_card",
+            )
+        )
+        > 0
+    )
+
+
 @pytest.mark.integration
 def test_filter_on_data_categories(
     db,
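The assertions above lean on two test helpers, `assert_rows_match` and `records_matching_fields`, whose implementations are not shown in this diff. A hypothetical equivalent of the latter (filtering the ExecutionLog dicts on keyword fields) is sketched below; the real helper lives in the repo's test utilities and may differ.

from typing import Any, Dict, List

def records_matching_fields_sketch(records: List[Dict[str, Any]], **fields: Any) -> List[Dict[str, Any]]:
    """Keep only the records whose values match every keyword filter,
    e.g. dataset_name="my_maria_db_1", collection_name="customer"."""
    return [rec for rec in records if all(rec.get(key) == value for key, value in fields.items())]

# Toy usage mirroring the assertions in the test above:
logs = [
    {"dataset_name": "my_maria_db_1", "collection_name": "customer", "status": "complete"},
    {"dataset_name": "my_maria_db_1", "collection_name": "orders", "status": "complete"},
]
assert len(records_matching_fields_sketch(logs, dataset_name="my_maria_db_1", collection_name="customer")) > 0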
80 changes: 79 additions & 1 deletion tests/service/privacy_request/request_runner_service_test.py
@@ -31,7 +31,7 @@
     SnowflakeConnector,
     RedshiftConnector,
     MicrosoftSQLServerConnector,
-    MySQLConnector,
+    MySQLConnector, MariaDBConnector,
 )
 from fidesops.service.masking.strategy.masking_strategy_factory import get_strategy
 from fidesops.service.privacy_request.request_runner_service import PrivacyRequestRunner
from fidesops.service.privacy_request.request_runner_service import PrivacyRequestRunner
Expand Down Expand Up @@ -254,6 +254,45 @@ def test_create_and_process_access_request_mysql(
pr.delete(db=db)


@pytest.mark.integration
@mock.patch("fidesops.models.privacy_request.PrivacyRequest.trigger_policy_webhook")
def test_create_and_process_access_request_mariadb(
trigger_webhook_mock,
mariadb_example_test_dataset_config,
db,
cache,
policy,
policy_pre_execution_webhooks,
policy_post_execution_webhooks,
):

customer_email = "[email protected]"
data = {
"requested_at": "2021-08-30T16:09:37.359Z",
"policy_key": policy.key,
"identity": {"email": customer_email},
}

pr = get_privacy_request_results(db, policy, cache, data)

results = pr.get_results()
assert len(results.keys()) == 11

for key in results.keys():
assert results[key] is not None
assert results[key] != {}

result_key_prefix = f"EN_{pr.id}__access_request__mariadb_example_test_dataset:"
customer_key = result_key_prefix + "customer"
assert results[customer_key][0]["email"] == customer_email

visit_key = result_key_prefix + "visit"
assert results[visit_key][0]["email"] == customer_email
# Both pre-execution webhooks and both post-execution webhooks were called
assert trigger_webhook_mock.call_count == 4
pr.delete(db=db)


@pytest.mark.integration_erasure
def test_create_and_process_erasure_request_specific_category(
postgres_example_test_dataset_config,
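Before the erasure test below, a note on the result keys asserted above: `pr.get_results()` returns a flat dict keyed as `EN_{privacy_request_id}__access_request__{dataset}:{collection}`, each value being the list of matched rows. A small illustrative helper for walking that structure (names and the toy dict are hypothetical; only the key format comes from the test):

from typing import Any, Dict, List

def collections_in_results(results: Dict[str, List[Any]], privacy_request_id: str) -> Dict[str, int]:
    """Count matched rows per collection under the dataset prefix used in the test above."""
    prefix = f"EN_{privacy_request_id}__access_request__mariadb_example_test_dataset:"
    return {key[len(prefix):]: len(rows) for key, rows in results.items() if key.startswith(prefix)}

# Toy results dict shaped like the one the test asserts on:
toy = {"EN_123__access_request__mariadb_example_test_dataset:customer": [{"email": "jane@example.com"}]}
print(collections_in_results(toy, "123"))  # {'customer': 1}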
@@ -371,6 +410,45 @@ def test_create_and_process_erasure_request_specific_category_mysql(
     assert customer_found


+@pytest.mark.integration_erasure
+def test_create_and_process_erasure_request_specific_category_mariadb(
+    mariadb_example_test_dataset_config,
+    cache,
+    db,
+    generate_auth_header,
+    erasure_policy,
+    connection_config_mariadb,
+):
+    customer_email = "[email protected]"
+    customer_id = 1
+    data = {
+        "requested_at": "2021-08-30T16:09:37.359Z",
+        "policy_key": erasure_policy.key,
+        "identity": {"email": customer_email},
+    }
+
+    pr = get_privacy_request_results(db, erasure_policy, cache, data)
+    pr.delete(db=db)
+
+    example_mariadb_uri = MariaDBConnector(connection_config_mariadb).build_uri()
+    engine = get_db_engine(database_uri=example_mariadb_uri)
+    SessionLocal = get_db_session(engine=engine)
+    integration_db = SessionLocal()
+    stmt = select(
+        column("id"),
+        column("name"),
+    ).select_from(table("customer"))
+    res = integration_db.execute(stmt).all()
+
+    customer_found = False
+    for row in res:
+        if customer_id in row:
+            customer_found = True
+            # Check that the `name` field is `None`
+            assert row.name is None
+    assert customer_found
+
+
 @pytest.mark.integration_erasure
 def test_create_and_process_erasure_request_generic_category(
     postgres_example_test_dataset_config,
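The new tests carry the `integration` and `integration_erasure` markers, which the repo presumably uses to keep them out of plain unit runs (an assumption; the pytest configuration and the docker/make targets that start the `mariadb_example` database are not shown in this diff). One way to select just the MariaDB cases with stock pytest, sketched via `pytest.main`:

import pytest

# Run only the marked MariaDB tests added in this commit; assumes the
# mariadb_example database from the repo's integration setup is already running.
pytest.main([
    "-m", "integration or integration_erasure",
    "-k", "mariadb",
    "tests/integration_tests/test_sql_task.py",
    "tests/service/privacy_request/request_runner_service_test.py",
])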
