Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added output_processor parameter to SQLExecuteQueryOperator and fixed a bug in the return_single_query_results handler when None is passed as split_statements #44781

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def return_single_query_results(sql: str | Iterable[str], return_last: bool, spl
:param split_statements: whether to split string statements.
:return: True if the hook should return single query results
"""
return isinstance(sql, str) and (return_last or not split_statements)
if split_statements is not None:
return isinstance(sql, str) and (return_last or not split_statements)
return isinstance(sql, str) and return_last


def fetch_all_handler(cursor) -> list[tuple] | None:
Expand Down
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
be removed in the future. Please import it from 'airflow.providers.common.sql.hooks.handlers'."""


def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool):
def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool | None):
warnings.warn(WARNING_MESSAGE.format("return_single_query_results"), DeprecationWarning, stacklevel=2)

from airflow.providers.common.sql.hooks import handlers
Expand Down
24 changes: 21 additions & 3 deletions providers/src/airflow/providers/common/sql/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ def _get_failed_checks(checks, col=None):
}


def default_output_processor(results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
return results


class BaseSQLOperator(BaseOperator):
"""
This is a base class for generic SQL Operator to get a DB Hook.
Expand Down Expand Up @@ -210,6 +214,8 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
:param autocommit: (optional) if True, each command is automatically committed (default: False).
:param parameters: (optional) the parameters to render the SQL query with.
:param handler: (optional) the function that will be applied to the cursor (default: fetch_all_handler).
:param output_processor: (optional) the function that will be applied to the result
(default: default_output_processor).
:param split_statements: (optional) if split single SQL string into statements. By default, defers
to the default value in the ``run`` method of the configured hook.
:param conn_id: the connection ID used to connect to the database
Expand All @@ -235,6 +241,13 @@ def __init__(
autocommit: bool = False,
parameters: Mapping | Iterable | None = None,
handler: Callable[[Any], list[tuple] | None] = fetch_all_handler,
output_processor: (
Callable[
[list[Any], list[Sequence[Sequence] | None]],
list[Any] | tuple[list[Sequence[Sequence] | None], list],
]
| None
) = None,
conn_id: str | None = None,
database: str | None = None,
split_statements: bool | None = None,
Expand All @@ -247,11 +260,14 @@ def __init__(
self.autocommit = autocommit
self.parameters = parameters
self.handler = handler
self._output_processor = output_processor or default_output_processor
self.split_statements = split_statements
self.return_last = return_last
self.show_return_value_in_logs = show_return_value_in_logs

def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
def _process_output(
self, results: list[Any], descriptions: list[Sequence[Sequence] | None]
) -> list[Any] | tuple[list[Sequence[Sequence] | None], list]:
"""
Process output before it is returned by the operator.

Expand All @@ -270,7 +286,7 @@ def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequen
"""
if self.show_return_value_in_logs:
self.log.info("Operator output is: %s", results)
return results
return self._output_processor(results, descriptions)

def _should_run_output_processing(self) -> bool:
return self.do_xcom_push
Expand All @@ -297,7 +313,9 @@ def execute(self, context):
# single query results are going to be returned, and we return the first element
# of the list in this case from the (always) list returned by _process_output
return self._process_output([output], hook.descriptions)[-1]
return self._process_output(output, hook.descriptions)
result = self._process_output(output, hook.descriptions)
self.log.info("result: %s", result)
return result

def prepare_template(self) -> None:
"""Parse template file for attribute parameters."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
autocommit: bool = False,
parameters: Mapping | Iterable | None = None,
handler: Callable[[Any], list[tuple] | None] = ...,
output_processor: (
Callable[
[list[Any], list[Sequence[Sequence] | None]],
list[Any] | tuple[list[Sequence[Sequence] | None], list],
]
| None
) = None,
conn_id: str | None = None,
database: str | None = None,
split_statements: bool | None = None,
Expand Down
1 change: 1 addition & 0 deletions providers/tests/common/sql/hooks/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class TestHandlers:
def test_return_single_query_results(self):
assert return_single_query_results("SELECT 1", return_last=True, split_statements=False)
assert return_single_query_results("SELECT 1", return_last=False, split_statements=False)
assert return_single_query_results("SELECT 1", return_last=False, split_statements=None) is False
assert return_single_query_results(["SELECT 1"], return_last=True, split_statements=False) is False
assert return_single_query_results(["SELECT 1"], return_last=False, split_statements=False) is False
assert return_single_query_results("SELECT 1", return_last=False, split_statements=True) is False
Expand Down
19 changes: 19 additions & 0 deletions providers/tests/common/sql/operators/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,25 @@ def test_dont_xcom_push(self, mock_get_db_hook, mock_process_output):
)
mock_process_output.assert_not_called()

@mock.patch.object(SQLExecuteQueryOperator, "get_db_hook")
def test_output_processor(self, mock_get_db_hook):
    """Check that a custom ``output_processor`` shapes the value returned by ``execute``."""
    data = [(1, "Alice"), (2, "Bob")]

    # Stand-in hook: ``run`` yields the rows, ``descriptions`` the column names.
    mock_hook = MagicMock()
    mock_hook.run.return_value = data
    mock_hook.descriptions = ("id", "name")
    mock_get_db_hook.return_value = mock_hook

    # The processor swaps the order so the test can tell it was really applied.
    operator = self._construct_operator(
        sql="SELECT * FROM users;",
        output_processor=lambda results, descriptions: (descriptions, results),
        return_last=False,
    )
    descriptions, result = operator.execute(context=MagicMock())

    assert descriptions == ("id", "name")
    assert result == [(1, "Alice"), (2, "Bob")]


class TestColumnCheckOperator:
valid_column_mapping = {
Expand Down