diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index 63b6457c759f2..a6883d9f08a45 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -34,7 +34,13 @@ def parse_boolean(val: str) -> Union[str, bool]:
     raise ValueError(f"{val!r} is not a boolean-like string value")
 
 
-def _get_failed_tests(checks):
+def _get_failed_checks(checks, col=None):
+    if col:
+        return [
+            f"Column: {col}\nCheck: {check},\nCheck Values: {check_values}\n"
+            for check, check_values in checks.items()
+            if not check_values["success"]
+        ]
     return [
         f"\tCheck: {check},\n\tCheck Values: {check_values}\n"
         for check, check_values in checks.items()
@@ -73,6 +79,13 @@ class SQLColumnCheckOperator(BaseSQLOperator):
             }
         }
 
+    :param partition_clause: a partial SQL statement that is added to a WHERE clause in the query built by
+        the operator, restricting the checks to a partition of the table, e.g.
+
+        .. code-block:: python
+
+            "date = '1970-01-01'"
+
     :param conn_id: the connection ID used to connect to the database
     :param database: name of database which overwrite the defined one in connection
 
@@ -81,6 +94,8 @@ class SQLColumnCheckOperator(BaseSQLOperator):
         :ref:`howto/operator:SQLColumnCheckOperator`
     """
 
+    template_fields = ("partition_clause",)
+
     column_checks = {
         "null_check": "SUM(CASE WHEN column IS NULL THEN 1 ELSE 0 END) AS column_null_check",
         "distinct_check": "COUNT(DISTINCT(column)) AS column_distinct_check",
@@ -94,6 +109,7 @@ def __init__(
         *,
         table: str,
         column_mapping: Dict[str, Dict[str, Any]],
+        partition_clause: Optional[str] = None,
         conn_id: Optional[str] = None,
         database: Optional[str] = None,
         **kwargs,
@@ -105,6 +121,7 @@ def __init__(
 
         self.table = table
         self.column_mapping = column_mapping
+        self.partition_clause = partition_clause
         # OpenLineage needs a valid SQL query with the input/output table(s) to parse
         self.sql = f"SELECT * FROM {self.table};"
 
@@ -114,8 +131,8 @@ def execute(self, context=None):
         for column in self.column_mapping:
             checks = [*self.column_mapping[column]]
             checks_sql = ",".join([self.column_checks[check].replace("column", column) for check in checks])
-
-            self.sql = f"SELECT {checks_sql} FROM {self.table};"
+            partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else ""
+            self.sql = f"SELECT {checks_sql} FROM {self.table} {partition_clause_statement};"
             records = hook.get_first(self.sql)
 
             if not records:
@@ -131,10 +148,10 @@ def execute(self, context=None):
                     self.column_mapping[column][checks[idx]], result, tolerance
                 )
 
-            failed_tests.extend(_get_failed_tests(self.column_mapping[column]))
+            failed_tests.extend(_get_failed_checks(self.column_mapping[column], column))
         if failed_tests:
             raise AirflowException(
-                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
+                f"Test failed.\nResults:\n{records!s}\n"
                 "The following tests have failed:"
                 f"\n{''.join(failed_tests)}"
             )
@@ -249,6 +266,14 @@ class SQLTableCheckOperator(BaseSQLOperator):
             "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
         }
 
+
+    :param partition_clause: a partial SQL statement that is added to a WHERE clause in the query built by
+        the operator, restricting the checks to a partition of the table, e.g.
+
+        .. code-block:: python
+
+            "date = '1970-01-01'"
+
     :param conn_id: the connection ID used to connect to the database
     :param database: name of database which overwrite the defined one in connection
 
@@ -257,14 +282,19 @@ class SQLTableCheckOperator(BaseSQLOperator):
         :ref:`howto/operator:SQLTableCheckOperator`
     """
 
-    sql_check_template = "CASE WHEN check_statement THEN 1 ELSE 0 END AS check_name"
-    sql_min_template = "MIN(check_name)"
+    template_fields = ("partition_clause",)
+
+    sql_check_template = """
+    SELECT '_check_name' AS check_name, MIN(_check_name) AS check_result
+    FROM(SELECT CASE WHEN check_statement THEN 1 ELSE 0 END AS _check_name FROM table)
+    """
 
     def __init__(
         self,
         *,
         table: str,
         checks: Dict[str, Dict[str, Any]],
+        partition_clause: Optional[str] = None,
         conn_id: Optional[str] = None,
         database: Optional[str] = None,
         **kwargs,
@@ -273,38 +303,37 @@ def __init__(
 
         self.table = table
         self.checks = checks
+        self.partition_clause = partition_clause
         # OpenLineage needs a valid SQL query with the input/output table(s) to parse
         self.sql = f"SELECT * FROM {self.table};"
 
     def execute(self, context=None):
         hook = self.get_db_hook()
-
-        check_names = [*self.checks]
-        check_mins_sql = ",".join(
-            self.sql_min_template.replace("check_name", check_name) for check_name in check_names
-        )
-        checks_sql = ",".join(
+        checks_sql = " UNION ALL ".join(
             [
-                self.sql_check_template.replace("check_statement", value["check_statement"]).replace(
-                    "check_name", check_name
-                )
+                self.sql_check_template.replace("check_statement", value["check_statement"])
+                .replace("_check_name", check_name)
+                .replace("table", self.table)
                 for check_name, value in self.checks.items()
             ]
         )
+        partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else ""
+        self.sql = f"SELECT check_name, check_result FROM ({checks_sql}) AS check_table {partition_clause_statement};"
 
-        self.sql = f"SELECT {check_mins_sql} FROM (SELECT {checks_sql} FROM {self.table});"
-        records = hook.get_first(self.sql)
+        records = hook.get_pandas_df(self.sql)
 
-        if not records:
+        if records.empty:
             raise AirflowException(f"The following query returned zero rows: {self.sql}")
 
-        self.log.info("Record: %s", records)
+        records.columns = records.columns.str.lower()
+        self.log.info("Record:\n%s", records)
 
-        for check in self.checks.keys():
-            for result in records:
-                self.checks[check]["success"] = parse_boolean(str(result))
+        for row in records.iterrows():
+            check = row[1].get("check_name")
+            result = row[1].get("check_result")
+            self.checks[check]["success"] = parse_boolean(str(result))
 
-        failed_tests = _get_failed_tests(self.checks)
+        failed_tests = _get_failed_checks(self.checks)
         if failed_tests:
             raise AirflowException(
                 f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py
index 63ef78ba1fe0f..da53e6134fe4e 100644
--- a/tests/providers/common/sql/operators/test_sql.py
+++ b/tests/providers/common/sql/operators/test_sql.py
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pandas as pd
 import pytest
 
 from airflow.exceptions import AirflowException
@@ -26,6 +27,9 @@ class MockHook:
     def get_first(self):
         return
 
+    def get_pandas_df(self):
+        return
+
 
 def _get_mock_db_hook():
     return MockHook()
@@ -95,20 +99,32 @@ class TestTableCheckOperator:
         "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
     }
 
-    def _construct_operator(self, monkeypatch, checks, return_vals):
-        def get_first_return(*arg):
-            return return_vals
+    def _construct_operator(self, monkeypatch, checks, return_df):
+        def get_pandas_df_return(*arg):
+            return return_df
 
         operator = SQLTableCheckOperator(task_id="test_task", table="test_table", checks=checks)
         monkeypatch.setattr(operator, "get_db_hook", _get_mock_db_hook)
-        monkeypatch.setattr(MockHook, "get_first", get_first_return)
+        monkeypatch.setattr(MockHook, "get_pandas_df", get_pandas_df_return)
 
         return operator
 
     def test_pass_all_checks_check(self, monkeypatch):
-        operator = self._construct_operator(monkeypatch, self.checks, ('1', 'y', 'true'))
+        df = pd.DataFrame(
+            data={
+                "check_name": ["row_count_check", "column_sum_check"],
+                "check_result": [
+                    "1",
+                    "y",
+                ],
+            }
+        )
+        operator = self._construct_operator(monkeypatch, self.checks, df)
         operator.execute()
 
     def test_fail_all_checks_check(self, monkeypatch):
-        operator = self._construct_operator(monkeypatch, self.checks, ('0', 'n', 'false'))
+        df = pd.DataFrame(
+            data={"check_name": ["row_count_check", "column_sum_check"], "check_result": ["0", "n"]}
+        )
+        operator = self._construct_operator(monkeypatch, self.checks, df)
         with pytest.raises(AirflowException):
             operator.execute()
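
For reviewers, a minimal DAG sketch exercising the new ``partition_clause`` parameter on both operators follows. The connection ID, table name, column name, and check values are illustrative only and are not part of this change.

.. code-block:: python

    import datetime

    from airflow import DAG
    from airflow.providers.common.sql.operators.sql import (
        SQLColumnCheckOperator,
        SQLTableCheckOperator,
    )

    with DAG(
        dag_id="partition_clause_example",
        start_date=datetime.datetime(2022, 1, 1),
        schedule_interval=None,
    ) as dag:
        # Column-level checks run only against rows matching the partition;
        # the clause is rendered into the WHERE clause of the generated query.
        column_checks = SQLColumnCheckOperator(
            task_id="column_checks",
            conn_id="my_db_conn",  # hypothetical connection ID
            table="my_table",  # hypothetical table
            column_mapping={"col_a": {"null_check": {"equal_to": 0}}},
            partition_clause="date = '1970-01-01'",
        )

        # partition_clause is in template_fields, so Jinja macros work too.
        table_checks = SQLTableCheckOperator(
            task_id="table_checks",
            conn_id="my_db_conn",
            table="my_table",
            checks={"row_count_check": {"check_statement": "COUNT(*) = 1000"}},
            partition_clause="date = '{{ ds }}'",
        )

        column_checks >> table_checks

With the clause above, ``SQLColumnCheckOperator.execute`` builds roughly ``SELECT SUM(CASE WHEN col_a IS NULL THEN 1 ELSE 0 END) AS col_a_null_check FROM my_table WHERE date = '1970-01-01';``. Because ``partition_clause`` is a template field, expressions such as ``{{ ds }}`` are rendered before the query is built.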