Common SQLCheckOperators Various Functionality Update #25164

Merged
Changes shown from 5 of 23 commits:
dc01866
Add batching to SQL Check Operators
denimalpaca Jul 18, 2022
7c25227
Fix bug with multiple table checks
denimalpaca Jul 18, 2022
2a3df61
Update test failure logic
denimalpaca Jul 18, 2022
66922f0
Add table alias to SQLTableCheckOperator query
denimalpaca Jul 21, 2022
554e8ba
Fix formatting error in operator
denimalpaca Jul 21, 2022
cf90083
Add batching to SQL Check Operators
denimalpaca Jul 18, 2022
7c20bf6
Fix bug with multiple table checks
denimalpaca Jul 18, 2022
d364e96
Update test failure logic
denimalpaca Jul 18, 2022
3c300e7
Add table alias to SQLTableCheckOperator query
denimalpaca Jul 21, 2022
5645b5d
Fix formatting error in operator
denimalpaca Jul 21, 2022
98a28c3
Merge branch 'sql_check_operators_various_functionality_update' of gi…
denimalpaca Jul 21, 2022
ee697e3
Move alias to proper query build statement
denimalpaca Jul 21, 2022
1754fdc
Add batching to SQL Check Operators
denimalpaca Jul 18, 2022
31d0e0e
Fix bug with multiple table checks
denimalpaca Jul 18, 2022
bc4140e
Update test failure logic
denimalpaca Jul 18, 2022
24ce964
Add table alias to SQLTableCheckOperator query
denimalpaca Jul 21, 2022
987f6c2
Fix formatting error in operator
denimalpaca Jul 21, 2022
a27fc2c
Merge branch 'sql_check_operators_various_functionality_update' of gi…
denimalpaca Jul 21, 2022
404eef5
Bug fixes and updates to test and operator
denimalpaca Jul 21, 2022
01bfb2f
Remove merge conflict lines
denimalpaca Jul 21, 2022
75d59d1
Rename parameter batch to partition_clause
denimalpaca Jul 21, 2022
0388ac4
Fix typo in docstring
denimalpaca Jul 21, 2022
719a830
Reformat operator file
denimalpaca Jul 21, 2022
61 changes: 36 additions & 25 deletions airflow/providers/common/sql/operators/sql.py
@@ -34,7 +34,13 @@ def parse_boolean(val: str) -> Union[str, bool]:
raise ValueError(f"{val!r} is not a boolean-like string value")


def _get_failed_tests(checks):
def _get_failed_checks(checks, col=None):
if col:
return [
f"Column: {col}\nCheck: {check},\nCheck Values: {check_values}\n"
for check, check_values in checks.items()
if not check_values["success"]
]
return [
f"\tCheck: {check},\n\tCheck Values: {check_values}\n"
for check, check_values in checks.items()
@@ -73,6 +79,7 @@ class SQLColumnCheckOperator(BaseSQLOperator):
}
}

:param batch: a SQL statement that is added to a WHERE clause to create batches
:param conn_id: the connection ID used to connect to the database
:param database: name of database which overwrite the defined one in connection

@@ -94,6 +101,7 @@ def __init__(
*,
table: str,
column_mapping: Dict[str, Dict[str, Any]],
batch: Optional[str] = None,
conn_id: Optional[str] = None,
database: Optional[str] = None,
**kwargs,
@@ -105,6 +113,7 @@ def __init__(

self.table = table
self.column_mapping = column_mapping
self.batch = batch
# OpenLineage needs a valid SQL query with the input/output table(s) to parse
self.sql = f"SELECT * FROM {self.table};"

@@ -114,8 +123,8 @@ def execute(self, context=None):
for column in self.column_mapping:
checks = [*self.column_mapping[column]]
checks_sql = ",".join([self.column_checks[check].replace("column", column) for check in checks])

self.sql = f"SELECT {checks_sql} FROM {self.table};"
batch_statement = f"WHERE {self.batch}" if self.batch else ""
self.sql = f"SELECT {checks_sql} FROM {self.table} {batch_statement};"
records = hook.get_first(self.sql)

if not records:
@@ -131,10 +140,10 @@ def execute(self, context=None):
self.column_mapping[column][checks[idx]], result, tolerance
)

failed_tests.extend(_get_failed_tests(self.column_mapping[column]))
failed_tests.extend(_get_failed_checks(self.column_mapping[column], column))
if failed_tests:
raise AirflowException(
f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
f"Test failed.\nResults:\n{records!s}\n"
"The following tests have failed:"
f"\n{''.join(failed_tests)}"
)
@@ -249,6 +258,7 @@ class SQLTableCheckOperator(BaseSQLOperator):
"column_sum_check": {"check_statement": "col_a + col_b < col_c"},
}

:param batch: a SQL statement that is added to a WHERE clause to create batches
:param conn_id: the connection ID used to connect to the database
:param database: name of database which overwrite the defined one in connection

@@ -257,14 +267,17 @@ class SQLTableCheckOperator(BaseSQLOperator):
:ref:`howto/operator:SQLTableCheckOperator`
"""

sql_check_template = "CASE WHEN check_statement THEN 1 ELSE 0 END AS check_name"
sql_min_template = "MIN(check_name)"
sql_check_template = """
SELECT '_check_name' AS check_name, MIN(_check_name) AS check_result
FROM(SELECT CASE WHEN check_statement THEN 1 ELSE 0 END AS _check_name FROM table) AS check_table
"""

def __init__(
self,
*,
table: str,
checks: Dict[str, Dict[str, Any]],
batch: Optional[str] = None,
conn_id: Optional[str] = None,
database: Optional[str] = None,
**kwargs,
@@ -273,38 +286,36 @@ def __init__(

self.table = table
self.checks = checks
self.batch = batch
# OpenLineage needs a valid SQL query with the input/output table(s) to parse
self.sql = f"SELECT * FROM {self.table};"

def execute(self, context=None):
hook = self.get_db_hook()

check_names = [*self.checks]
check_mins_sql = ",".join(
self.sql_min_template.replace("check_name", check_name) for check_name in check_names
)
checks_sql = ",".join(
checks_sql = " UNION ALL ".join(
[
self.sql_check_template.replace("check_statement", value["check_statement"]).replace(
"check_name", check_name
)
self.sql_check_template.replace("check_statement", value["check_statement"])
.replace("_check_name", check_name)
.replace("table", self.table)
for check_name, value in self.checks.items()
]
)
batch_statement = f"WHERE {self.batch}" if self.batch else ""
self.sql = f"SELECT check_name, check_result FROM ({checks_sql}) {batch_statement};"
records = hook.get_pandas_df(self.sql)
Member:
Why did we change from getting records to getting this as a pandas dataframe? This now places a hard requirement on using pandas for this operator, whereas previously pandas was almost entirely optional.

Contributor Author:
So, this was changed because with hook.get_first there was an issue with how the SQL was being written that caused only fully aggregated checks to be returned unless the syntax of the SQL query was changed, but that would require either a fetch_all or get_pandas_df call, as the new SQL needs to return multiple rows. It seemed much easier and possibly more efficient to use pandas here, but if a fetch_all seems more reasonable this can be changed. Happy to explain more about the specific issue if curious.
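
For concreteness, a fetch_all-style version of this loop could look roughly like the sketch below. It is not part of the diff: it assumes DbApiHook.get_records() is available and that each row comes back as a (check_name, check_result) pair, and it reuses hook, parse_boolean, and AirflowException from the surrounding module.

```python
# Sketch only: roughly how execute() could consume the per-check rows without
# pandas, assuming one (check_name, check_result) tuple per row.
records = hook.get_records(self.sql)

if not records:
    raise AirflowException(f"The following query returned zero rows: {self.sql}")

self.log.info("Records: %s", records)

for check_name, check_result in records:
    # Mark each named check as passed/failed based on its MIN() result.
    self.checks[check_name]["success"] = parse_boolean(str(check_result))
```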

Member:
I'd be curious to hear more about it :D

Contributor Author:
To expand a bit more: a check like col_a + col_b >= col_c would not work when there were multiple checks in the operator, as the previous SELECT statement would then fail and require a GROUP BY clause, iirc. So the check would either have to be in its own operator, or be amended like so: SUM(col_a) + SUM(col_b) >= SUM(col_c), which isn't quite the same check. So the query needed to be updated, and the one that I wound up using returns multiple rows. That means get_first is no longer useful, and in the moment of writing it seemed that handling things with a pandas dataframe might be easier in the long term if more complicated uses of the pulled-in data were implemented. But as of now I see how it's unneeded.
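
To make the new query shape concrete, the standalone sketch below expands the sql_check_template from this diff for two checks. The table name and the row-count statement are illustrative; only column_sum_check's statement comes from the test file.

```python
# Standalone sketch: expands the PR's sql_check_template for two checks so the
# multi-row UNION ALL result is visible.
sql_check_template = """
SELECT '_check_name' AS check_name, MIN(_check_name) AS check_result
FROM(SELECT CASE WHEN check_statement THEN 1 ELSE 0 END AS _check_name FROM table) AS check_table
"""

checks = {
    "row_count_check": {"check_statement": "COUNT(*) >= 1000"},  # assumed statement
    "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
}

checks_sql = " UNION ALL ".join(
    sql_check_template.replace("check_statement", value["check_statement"])
    .replace("_check_name", check_name)
    .replace("table", "my_table")  # illustrative table name
    for check_name, value in checks.items()
)
print(f"SELECT check_name, check_result FROM ({checks_sql});")
# The outer SELECT returns one (check_name, check_result) row per check,
# which is why get_first() would silently drop every check after the first.
```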

Member:
Thanks... Is the operator easy for users to understand? I am afraid this is something we need good, clear how-tos and examples for, because people will have a hard time using it and raise too many questions.

Contributor Author (denimalpaca, Aug 29, 2022):
Ideally most users won't need to learn about why .fetch_all is being used instead of .get_first 🙃. I'm working with some users of the operator right now and seeing what's complicated, to make sure the docs are robust. I also have a working example DAG showing how to use the operator (with several more planned) here.
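
For a quick picture of intended usage, a minimal, hypothetical DAG snippet as of this revision might look like the following. The connection ID, table, and check statements are made up, and the batch parameter is renamed to partition_clause later in this PR, per the commit list.

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.common.sql.operators.sql import SQLTableCheckOperator

with DAG(
    dag_id="sql_table_check_example",  # hypothetical DAG
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
) as dag:
    table_checks = SQLTableCheckOperator(
        task_id="table_checks",
        conn_id="my_db_conn",  # hypothetical connection ID
        table="my_table",      # hypothetical table
        checks={
            "row_count_check": {"check_statement": "COUNT(*) >= 1000"},
            "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
        },
        # WHERE-clause fragment applied to all checks; renamed to
        # partition_clause later in this PR.
        batch="created_at >= '2022-01-01'",
    )
```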


self.sql = f"SELECT {check_mins_sql} FROM (SELECT {checks_sql} FROM {self.table});"
records = hook.get_first(self.sql)

if not records:
if records.empty:
raise AirflowException(f"The following query returned zero rows: {self.sql}")

self.log.info("Record: %s", records)
records.columns = records.columns.str.lower()
self.log.info("Record:\n%s", records)

for check in self.checks.keys():
for result in records:
self.checks[check]["success"] = parse_boolean(str(result))
for row in records.iterrows():
check = row[1].get("check_name")
result = row[1].get("check_result")
self.checks[check]["success"] = parse_boolean(str(result))

failed_tests = _get_failed_tests(self.checks)
failed_tests = _get_failed_checks(self.checks)
if failed_tests:
raise AirflowException(
f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
28 changes: 22 additions & 6 deletions tests/providers/common/sql/operators/test_sql.py
@@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.

import pandas as pd
import pytest

from airflow.exceptions import AirflowException
@@ -26,6 +27,9 @@ class MockHook:
def get_first(self):
return

def get_pandas_df(self):
return


def _get_mock_db_hook():
return MockHook()
@@ -95,20 +99,32 @@ class TestTableCheckOperator:
"column_sum_check": {"check_statement": "col_a + col_b < col_c"},
}

def _construct_operator(self, monkeypatch, checks, return_vals):
def get_first_return(*arg):
return return_vals
def _construct_operator(self, monkeypatch, checks, return_df):
def get_pandas_df_return(*arg):
return return_df

operator = SQLTableCheckOperator(task_id="test_task", table="test_table", checks=checks)
monkeypatch.setattr(operator, "get_db_hook", _get_mock_db_hook)
monkeypatch.setattr(MockHook, "get_first", get_first_return)
monkeypatch.setattr(MockHook, "get_pandas_df", get_pandas_df_return)
return operator

def test_pass_all_checks_check(self, monkeypatch):
operator = self._construct_operator(monkeypatch, self.checks, ('1', 'y', 'true'))
df = pd.DataFrame(
data={
"test_name": ["row_count_check", "column_sum_check"],
"test_result": [
"1",
"y",
],
}
)
operator = self._construct_operator(monkeypatch, self.checks, df)
operator.execute()

def test_fail_all_checks_check(self, monkeypatch):
operator = self._construct_operator(monkeypatch, self.checks, ('0', 'n', 'false'))
df = pd.DataFrame(
data={"test_name": ["row_count_check", "column_sum_check"], "test_result": ["0", "n"]}
)
operator = self._construct_operator(monkeypatch, self.checks, df)
with pytest.raises(AirflowException):
operator.execute()