From 710fcae3293faadb05b2fda74c21b556bcbbb22a Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 9 Dec 2024 08:37:58 +0100 Subject: [PATCH 1/5] refactor: Added output_processor parameter to SQLQueryOperator --- .../providers/common/sql/operators/sql.py | 12 ++++++++++-- .../providers/common/sql/operators/sql.pyi | 1 + .../tests/common/sql/operators/test_sql.py | 18 +++++++++++++++++- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/providers/src/airflow/providers/common/sql/operators/sql.py b/providers/src/airflow/providers/common/sql/operators/sql.py index 3643d01b28eb9..77aed960bb71c 100644 --- a/providers/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/src/airflow/providers/common/sql/operators/sql.py @@ -116,6 +116,10 @@ def _get_failed_checks(checks, col=None): } +def default_output_processor(results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]: + return results + + class BaseSQLOperator(BaseOperator): """ This is a base class for generic SQL Operator to get a DB Hook. @@ -210,6 +214,8 @@ class SQLExecuteQueryOperator(BaseSQLOperator): :param autocommit: (optional) if True, each command is automatically committed (default: False). :param parameters: (optional) the parameters to render the SQL query with. :param handler: (optional) the function that will be applied to the cursor (default: fetch_all_handler). + :param output_processor: (optional) the function that will be applied to the result + (default: default_output_processor). :param split_statements: (optional) if split single SQL string into statements. By default, defers to the default value in the ``run`` method of the configured hook. :param conn_id: the connection ID used to connect to the database @@ -235,6 +241,7 @@ def __init__( autocommit: bool = False, parameters: Mapping | Iterable | None = None, handler: Callable[[Any], list[tuple] | None] = fetch_all_handler, + output_processor: Callable[[list[Any], list[Sequence[Sequence] | None]], list[Any] | tuple[list[Sequence[Sequence] | None], list]] | None = None, conn_id: str | None = None, database: str | None = None, split_statements: bool | None = None, @@ -247,11 +254,12 @@ def __init__( self.autocommit = autocommit self.parameters = parameters self.handler = handler + self._output_processor = output_processor or default_output_processor self.split_statements = split_statements self.return_last = return_last self.show_return_value_in_logs = show_return_value_in_logs - def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]: + def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any] | tuple[list[Sequence[Sequence] | None], list]: """ Process output before it is returned by the operator. @@ -270,7 +278,7 @@ def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequen """ if self.show_return_value_in_logs: self.log.info("Operator output is: %s", results) - return results + return self._output_processor(results=results, descriptions=descriptions) def _should_run_output_processing(self) -> bool: return self.do_xcom_push diff --git a/providers/src/airflow/providers/common/sql/operators/sql.pyi b/providers/src/airflow/providers/common/sql/operators/sql.pyi index 6921e3411ea01..d4d91debafb95 100644 --- a/providers/src/airflow/providers/common/sql/operators/sql.pyi +++ b/providers/src/airflow/providers/common/sql/operators/sql.pyi @@ -78,6 +78,7 @@ class SQLExecuteQueryOperator(BaseSQLOperator): autocommit: bool = False, parameters: Mapping | Iterable | None = None, handler: Callable[[Any], list[tuple] | None] = ..., + output_processor: Callable[[list[Any], list[Sequence[Sequence] | None]], list[Any] | tuple[list[Sequence[Sequence] | None], list]] | None = None, conn_id: str | None = None, database: str | None = None, split_statements: bool | None = None, diff --git a/providers/tests/common/sql/operators/test_sql.py b/providers/tests/common/sql/operators/test_sql.py index 544fcb40b4d3f..4fbabaa10ec20 100644 --- a/providers/tests/common/sql/operators/test_sql.py +++ b/providers/tests/common/sql/operators/test_sql.py @@ -23,7 +23,6 @@ from unittest.mock import MagicMock import pytest - from airflow import DAG from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import Connection, DagRun, TaskInstance as TI, XCom @@ -148,6 +147,23 @@ def test_dont_xcom_push(self, mock_get_db_hook, mock_process_output): ) mock_process_output.assert_not_called() + @mock.patch.object(SQLExecuteQueryOperator, "get_db_hook") + def test_output_processor(self, mock_get_db_hook): + data = [(1, "Alice"), (2, "Bob")] + + mock_hook = MagicMock() + mock_hook.run.return_value = data + mock_hook.descriptions = ("id", "name") + mock_get_db_hook.return_value = mock_hook + + operator = self._construct_operator( + sql="SELECT * FROM users;", + output_processor=lambda results, descriptions: (descriptions, results), + ) + result = operator.execute(context=MagicMock()) + + assert result == [("id", "name"), [(1, "Alice"), (2, "Bob")]] + class TestColumnCheckOperator: valid_column_mapping = { From c5f3b89621d3d7a36211c7e35e107b55ad7c3079 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 9 Dec 2024 10:00:43 +0100 Subject: [PATCH 2/5] refactor: Reformatted SQLQueryOperator --- .../airflow/providers/common/sql/operators/sql.py | 12 ++++++++++-- .../airflow/providers/common/sql/operators/sql.pyi | 8 +++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/providers/src/airflow/providers/common/sql/operators/sql.py b/providers/src/airflow/providers/common/sql/operators/sql.py index 77aed960bb71c..a034ebdb56402 100644 --- a/providers/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/src/airflow/providers/common/sql/operators/sql.py @@ -241,7 +241,13 @@ def __init__( autocommit: bool = False, parameters: Mapping | Iterable | None = None, handler: Callable[[Any], list[tuple] | None] = fetch_all_handler, - output_processor: Callable[[list[Any], list[Sequence[Sequence] | None]], list[Any] | tuple[list[Sequence[Sequence] | None], list]] | None = None, + output_processor: ( + Callable[ + [list[Any], list[Sequence[Sequence] | None]], + list[Any] | tuple[list[Sequence[Sequence] | None], list], + ] + | None + ) = None, conn_id: str | None = None, database: str | None = None, split_statements: bool | None = None, @@ -259,7 +265,9 @@ def __init__( self.return_last = return_last self.show_return_value_in_logs = show_return_value_in_logs - def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any] | tuple[list[Sequence[Sequence] | None], list]: + def _process_output( + self, results: list[Any], descriptions: list[Sequence[Sequence] | None] + ) -> list[Any] | tuple[list[Sequence[Sequence] | None], list]: """ Process output before it is returned by the operator. diff --git a/providers/src/airflow/providers/common/sql/operators/sql.pyi b/providers/src/airflow/providers/common/sql/operators/sql.pyi index d4d91debafb95..6f89fc8b6ebb2 100644 --- a/providers/src/airflow/providers/common/sql/operators/sql.pyi +++ b/providers/src/airflow/providers/common/sql/operators/sql.pyi @@ -78,7 +78,13 @@ class SQLExecuteQueryOperator(BaseSQLOperator): autocommit: bool = False, parameters: Mapping | Iterable | None = None, handler: Callable[[Any], list[tuple] | None] = ..., - output_processor: Callable[[list[Any], list[Sequence[Sequence] | None]], list[Any] | tuple[list[Sequence[Sequence] | None], list]] | None = None, + output_processor: ( + Callable[ + [list[Any], list[Sequence[Sequence] | None]], + list[Any] | tuple[list[Sequence[Sequence] | None], list], + ] + | None + ) = None, conn_id: str | None = None, database: str | None = None, split_statements: bool | None = None, From 45126297086b837748b9424b669ead57c29778c5 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 9 Dec 2024 12:08:37 +0100 Subject: [PATCH 3/5] refactor: Fixed return_single_query_results when none is passed as split_statements --- .../src/airflow/providers/common/sql/hooks/handlers.py | 4 +++- providers/src/airflow/providers/common/sql/hooks/sql.py | 2 +- providers/src/airflow/providers/common/sql/operators/sql.py | 6 ++++-- providers/tests/common/sql/hooks/test_handlers.py | 1 + providers/tests/common/sql/operators/test_sql.py | 6 ++++-- 5 files changed, 13 insertions(+), 6 deletions(-) diff --git a/providers/src/airflow/providers/common/sql/hooks/handlers.py b/providers/src/airflow/providers/common/sql/hooks/handlers.py index 3636cc214d213..b399dc0023f6f 100644 --- a/providers/src/airflow/providers/common/sql/hooks/handlers.py +++ b/providers/src/airflow/providers/common/sql/hooks/handlers.py @@ -44,7 +44,9 @@ def return_single_query_results(sql: str | Iterable[str], return_last: bool, spl :param split_statements: whether to split string statements. :return: True if the hook should return single query results """ - return isinstance(sql, str) and (return_last or not split_statements) + if split_statements is not None: + return isinstance(sql, str) and (return_last or not split_statements) + return isinstance(sql, str) and return_last def fetch_all_handler(cursor) -> list[tuple] | None: diff --git a/providers/src/airflow/providers/common/sql/hooks/sql.py b/providers/src/airflow/providers/common/sql/hooks/sql.py index bd8780a750dbd..f4d107f0c5f3e 100644 --- a/providers/src/airflow/providers/common/sql/hooks/sql.py +++ b/providers/src/airflow/providers/common/sql/hooks/sql.py @@ -59,7 +59,7 @@ be removed in the future. Please import it from 'airflow.providers.common.sql.hooks.handlers'.""" -def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool): +def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool | None): warnings.warn(WARNING_MESSAGE.format("return_single_query_results"), DeprecationWarning, stacklevel=2) from airflow.providers.common.sql.hooks import handlers diff --git a/providers/src/airflow/providers/common/sql/operators/sql.py b/providers/src/airflow/providers/common/sql/operators/sql.py index a034ebdb56402..b0f5fce6a2b8a 100644 --- a/providers/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/src/airflow/providers/common/sql/operators/sql.py @@ -286,7 +286,7 @@ def _process_output( """ if self.show_return_value_in_logs: self.log.info("Operator output is: %s", results) - return self._output_processor(results=results, descriptions=descriptions) + return self._output_processor(results, descriptions) def _should_run_output_processing(self) -> bool: return self.do_xcom_push @@ -313,7 +313,9 @@ def execute(self, context): # single query results are going to be returned, and we return the first element # of the list in this case from the (always) list returned by _process_output return self._process_output([output], hook.descriptions)[-1] - return self._process_output(output, hook.descriptions) + result = self._process_output(output, hook.descriptions) + self.log.info("result: %s", result) + return result def prepare_template(self) -> None: """Parse template file for attribute parameters.""" diff --git a/providers/tests/common/sql/hooks/test_handlers.py b/providers/tests/common/sql/hooks/test_handlers.py index 9adf8df67c82c..8fd3ed8b65f18 100644 --- a/providers/tests/common/sql/hooks/test_handlers.py +++ b/providers/tests/common/sql/hooks/test_handlers.py @@ -30,6 +30,7 @@ class TestHandlers: def test_return_single_query_results(self): assert return_single_query_results("SELECT 1", return_last=True, split_statements=False) assert return_single_query_results("SELECT 1", return_last=False, split_statements=False) + assert return_single_query_results("SELECT 1", return_last=False, split_statements=None) is False assert return_single_query_results(["SELECT 1"], return_last=True, split_statements=False) is False assert return_single_query_results(["SELECT 1"], return_last=False, split_statements=False) is False assert return_single_query_results("SELECT 1", return_last=False, split_statements=True) is False diff --git a/providers/tests/common/sql/operators/test_sql.py b/providers/tests/common/sql/operators/test_sql.py index 4fbabaa10ec20..a94a074890ccd 100644 --- a/providers/tests/common/sql/operators/test_sql.py +++ b/providers/tests/common/sql/operators/test_sql.py @@ -159,10 +159,12 @@ def test_output_processor(self, mock_get_db_hook): operator = self._construct_operator( sql="SELECT * FROM users;", output_processor=lambda results, descriptions: (descriptions, results), + return_last=False, ) - result = operator.execute(context=MagicMock()) + descriptions, result = operator.execute(context=MagicMock()) - assert result == [("id", "name"), [(1, "Alice"), (2, "Bob")]] + assert descriptions == ("id", "name") + assert result == [(1, "Alice"), (2, "Bob")] class TestColumnCheckOperator: From d1e12f60a0c6e17d27b551519b1284e775ef861a Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 9 Dec 2024 12:43:33 +0100 Subject: [PATCH 4/5] refactor: Reformatted SQLExecuteOperator --- providers/src/airflow/providers/common/sql/operators/sql.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/providers/src/airflow/providers/common/sql/operators/sql.py b/providers/src/airflow/providers/common/sql/operators/sql.py index b0f5fce6a2b8a..61339c1260425 100644 --- a/providers/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/src/airflow/providers/common/sql/operators/sql.py @@ -243,9 +243,9 @@ def __init__( handler: Callable[[Any], list[tuple] | None] = fetch_all_handler, output_processor: ( Callable[ - [list[Any], list[Sequence[Sequence] | None]], - list[Any] | tuple[list[Sequence[Sequence] | None], list], - ] + [list[Any], list[Sequence[Sequence] | None]], + list[Any] | tuple[list[Sequence[Sequence] | None], list], + ] | None ) = None, conn_id: str | None = None, From 8f303e3afd583ebcd9d20ba9c2c9cfc3c3936ea0 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 9 Dec 2024 14:07:11 +0100 Subject: [PATCH 5/5] refactor: Added white line --- providers/tests/common/sql/operators/test_sql.py | 1 + 1 file changed, 1 insertion(+) diff --git a/providers/tests/common/sql/operators/test_sql.py b/providers/tests/common/sql/operators/test_sql.py index bc60c6c3d49a7..133d51ac75753 100644 --- a/providers/tests/common/sql/operators/test_sql.py +++ b/providers/tests/common/sql/operators/test_sql.py @@ -23,6 +23,7 @@ from unittest.mock import MagicMock import pytest + from airflow import DAG from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import Connection, DagRun, TaskInstance as TI, XCom