From e6f445129a998eab62d71bd91b4a5f46cd77c1de Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Thu, 19 Oct 2023 15:58:48 +0400 Subject: [PATCH] Pass additional arguments from Slack's Operators/Notifiers to Hooks (#35039) --- .../providers/slack/notifications/slack.py | 27 +++++++- .../slack/notifications/slack_webhook.py | 14 ++++- airflow/providers/slack/operators/slack.py | 27 +++++++- .../slack/operators/slack_webhook.py | 18 +++++- .../providers/slack/transfers/sql_to_slack.py | 30 ++++++++- .../slack/notifications/test_slack.py | 30 ++++++++- .../slack/notifications/test_slack_webhook.py | 26 +++++++- tests/providers/slack/operators/test_slack.py | 32 +++++++++- .../slack/operators/test_slack_webhook.py | 24 ++++++-- .../slack/transfers/test_sql_to_slack.py | 61 +++++++++++++++++-- 10 files changed, 261 insertions(+), 28 deletions(-) diff --git a/airflow/providers/slack/notifications/slack.py b/airflow/providers/slack/notifications/slack.py index f2359b6fcb8ad..5ff63b362889e 100644 --- a/airflow/providers/slack/notifications/slack.py +++ b/airflow/providers/slack/notifications/slack.py @@ -19,9 +19,10 @@ import json from functools import cached_property -from typing import Sequence +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowOptionalProviderFeatureException +from airflow.providers.slack.hooks.slack import SlackHook try: from airflow.notifications.basenotifier import BaseNotifier @@ -30,7 +31,8 @@ "Failed to import BaseNotifier. This feature is only available in Airflow versions >= 2.6.0" ) -from airflow.providers.slack.hooks.slack import SlackHook +if TYPE_CHECKING: + from slack_sdk.http_retry import RetryHandler ICON_URL: str = "https://raw.githubusercontent.com/apache/airflow/2.5.0/airflow/www/static/pin_100.png" @@ -46,6 +48,11 @@ class SlackNotifier(BaseNotifier): :param icon_url: The icon to use for the message. Optional :param attachments: A list of attachments to send with the message. Optional :param blocks: A list of blocks to send with the message. Optional + :param timeout: The maximum number of seconds the client will wait to connect + and receive a response from Slack. Optional + :param base_url: A string representing the Slack API base URL. Optional + :param proxy: Proxy to make the Slack API call. Optional + :param retry_handlers: List of handlers to customize retry logic in ``slack_sdk.WebClient``. Optional """ template_fields = ("text", "channel", "username", "attachments", "blocks") @@ -60,6 +67,10 @@ def __init__( icon_url: str = ICON_URL, attachments: Sequence = (), blocks: Sequence = (), + base_url: str | None = None, + proxy: str | None = None, + timeout: int | None = None, + retry_handlers: list[RetryHandler] | None = None, ): super().__init__() self.slack_conn_id = slack_conn_id @@ -69,11 +80,21 @@ def __init__( self.icon_url = icon_url self.attachments = attachments self.blocks = blocks + self.base_url = base_url + self.timeout = timeout + self.proxy = proxy + self.retry_handlers = retry_handlers @cached_property def hook(self) -> SlackHook: """Slack Hook.""" - return SlackHook(slack_conn_id=self.slack_conn_id) + return SlackHook( + slack_conn_id=self.slack_conn_id, + base_url=self.base_url, + timeout=self.timeout, + proxy=self.proxy, + retry_handlers=self.retry_handlers, + ) def notify(self, context): """Send a message to a Slack Channel.""" diff --git a/airflow/providers/slack/notifications/slack_webhook.py b/airflow/providers/slack/notifications/slack_webhook.py index 67616e5a51d2a..0fab447b6e282 100644 --- a/airflow/providers/slack/notifications/slack_webhook.py +++ b/airflow/providers/slack/notifications/slack_webhook.py @@ -18,6 +18,7 @@ from __future__ import annotations from functools import cached_property +from typing import TYPE_CHECKING from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook @@ -29,6 +30,9 @@ "Failed to import BaseNotifier. This feature is only available in Airflow versions >= 2.6.0" ) +if TYPE_CHECKING: + from slack_sdk.http_retry import RetryHandler + class SlackWebhookNotifier(BaseNotifier): """ @@ -45,9 +49,10 @@ class SlackWebhookNotifier(BaseNotifier): :param unfurl_links: Option to indicate whether text url should unfurl. Optional :param unfurl_media: Option to indicate whether media url should unfurl. Optional :param timeout: The maximum number of seconds the client will wait to connect. Optional - and receive a response from Slack. If not set than default WebhookClient value will use. + and receive a response from Slack. Optional :param proxy: Proxy to make the Slack Incoming Webhook call. Optional :param attachments: A list of attachments to send with the message. Optional + :param retry_handlers: List of handlers to customize retry logic in ``slack_sdk.WebhookClient``. Optional """ template_fields = ("slack_webhook_conn_id", "text", "attachments", "blocks", "proxy", "timeout") @@ -63,6 +68,7 @@ def __init__( proxy: str | None = None, timeout: int | None = None, attachments: list | None = None, + retry_handlers: list[RetryHandler] | None = None, ): super().__init__() self.slack_webhook_conn_id = slack_webhook_conn_id @@ -73,12 +79,16 @@ def __init__( self.unfurl_media = unfurl_media self.timeout = timeout self.proxy = proxy + self.retry_handlers = retry_handlers @cached_property def hook(self) -> SlackWebhookHook: """Slack Incoming Webhook Hook.""" return SlackWebhookHook( - slack_webhook_conn_id=self.slack_webhook_conn_id, proxy=self.proxy, timeout=self.timeout + slack_webhook_conn_id=self.slack_webhook_conn_id, + proxy=self.proxy, + timeout=self.timeout, + retry_handlers=self.retry_handlers, ) def notify(self, context): diff --git a/airflow/providers/slack/operators/slack.py b/airflow/providers/slack/operators/slack.py index adf21a93722d0..ffefbbac3aeed 100644 --- a/airflow/providers/slack/operators/slack.py +++ b/airflow/providers/slack/operators/slack.py @@ -20,12 +20,15 @@ import json import warnings from functools import cached_property -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any, Sequence from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.models import BaseOperator from airflow.providers.slack.hooks.slack import SlackHook +if TYPE_CHECKING: + from slack_sdk.http_retry import RetryHandler + class SlackAPIOperator(BaseOperator): """Base Slack Operator class. @@ -34,7 +37,11 @@ class SlackAPIOperator(BaseOperator): which its password is Slack API token. :param method: The Slack API Method to Call (https://api.slack.com/methods). Optional :param api_params: API Method call parameters (https://api.slack.com/methods). Optional - :param client_args: Slack Hook parameters. Optional. Check airflow.providers.slack.hooks.SlackHook + :param timeout: The maximum number of seconds the client will wait to connect + and receive a response from Slack. Optional + :param base_url: A string representing the Slack API base URL. Optional + :param proxy: Proxy to make the Slack API call. Optional + :param retry_handlers: List of handlers to customize retry logic in ``slack_sdk.WebClient``. Optional """ def __init__( @@ -43,17 +50,31 @@ def __init__( slack_conn_id: str = SlackHook.default_conn_name, method: str | None = None, api_params: dict | None = None, + base_url: str | None = None, + proxy: str | None = None, + timeout: int | None = None, + retry_handlers: list[RetryHandler] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) self.slack_conn_id = slack_conn_id self.method = method self.api_params = api_params + self.base_url = base_url + self.timeout = timeout + self.proxy = proxy + self.retry_handlers = retry_handlers @cached_property def hook(self) -> SlackHook: """Slack Hook.""" - return SlackHook(slack_conn_id=self.slack_conn_id) + return SlackHook( + slack_conn_id=self.slack_conn_id, + base_url=self.base_url, + timeout=self.timeout, + proxy=self.proxy, + retry_handlers=self.retry_handlers, + ) def construct_api_call_params(self) -> Any: """API call parameters used by the execute function. diff --git a/airflow/providers/slack/operators/slack_webhook.py b/airflow/providers/slack/operators/slack_webhook.py index f386adf6b8d01..9a36a806e9284 100644 --- a/airflow/providers/slack/operators/slack_webhook.py +++ b/airflow/providers/slack/operators/slack_webhook.py @@ -24,6 +24,8 @@ from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook if TYPE_CHECKING: + from slack_sdk.http_retry import RetryHandler + from airflow.utils.context import Context @@ -51,7 +53,10 @@ class SlackWebhookOperator(BaseOperator): :param username: The username to post to slack with :param icon_emoji: The emoji to use as icon for the user posting to Slack :param icon_url: The icon image URL string to use in place of the default icon. - :param proxy: Proxy to use to make the Slack webhook call + :param proxy: Proxy to make the Slack Incoming Webhook call. Optional + :param timeout: The maximum number of seconds the client will wait to connect + and receive a response from Slack. Optional + :param retry_handlers: List of handlers to customize retry logic in ``slack_sdk.WebhookClient``. Optional """ template_fields: Sequence[str] = ( @@ -75,6 +80,8 @@ def __init__( icon_emoji: str | None = None, icon_url: str | None = None, proxy: str | None = None, + timeout: int | None = None, + retry_handlers: list[RetryHandler] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -87,11 +94,18 @@ def __init__( self.username = username self.icon_emoji = icon_emoji self.icon_url = icon_url + self.timeout = timeout + self.retry_handlers = retry_handlers @cached_property def hook(self) -> SlackWebhookHook: """Create and return an SlackWebhookHook (cached).""" - return SlackWebhookHook(slack_webhook_conn_id=self.slack_webhook_conn_id, proxy=self.proxy) + return SlackWebhookHook( + slack_webhook_conn_id=self.slack_webhook_conn_id, + proxy=self.proxy, + timeout=self.timeout, + retry_handlers=self.retry_handlers, + ) def execute(self, context: Context) -> None: """Call the SlackWebhookHook to post the provided Slack message.""" diff --git a/airflow/providers/slack/transfers/sql_to_slack.py b/airflow/providers/slack/transfers/sql_to_slack.py index 537abe3b9b56b..9c127d77a1350 100644 --- a/airflow/providers/slack/transfers/sql_to_slack.py +++ b/airflow/providers/slack/transfers/sql_to_slack.py @@ -30,6 +30,7 @@ if TYPE_CHECKING: import pandas as pd + from slack_sdk.http_retry import RetryHandler from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.utils.context import Context @@ -44,6 +45,10 @@ class BaseSqlToSlackOperator(BaseOperator): :param sql_hook_params: Extra config params to be passed to the underlying hook. Should match the desired hook constructor params. :param parameters: The parameters to pass to the SQL query. + :param slack_proxy: Proxy to make the Slack Incoming Webhook / API calls. Optional + :param slack_timeout: The maximum number of seconds the client will wait to connect + and receive a response from Slack. Optional + :param slack_retry_handlers: List of handlers to customize retry logic. Optional """ def __init__( @@ -53,6 +58,9 @@ def __init__( sql_conn_id: str, sql_hook_params: dict | None = None, parameters: Iterable | Mapping[str, Any] | None = None, + slack_proxy: str | None = None, + slack_timeout: int | None = None, + slack_retry_handlers: list[RetryHandler] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -60,6 +68,9 @@ def __init__( self.sql_hook_params = sql_hook_params self.sql = sql self.parameters = parameters + self.slack_proxy = slack_proxy + self.slack_timeout = slack_timeout + self.slack_retry_handlers = slack_retry_handlers def _get_hook(self) -> DbApiHook: self.log.debug("Get connection for %s", self.sql_conn_id) @@ -146,7 +157,12 @@ def _render_and_send_slack_message(self, context, df) -> None: slack_hook.send(text=self.slack_message, channel=self.slack_channel) def _get_slack_hook(self) -> SlackWebhookHook: - return SlackWebhookHook(slack_webhook_conn_id=self.slack_conn_id) + return SlackWebhookHook( + slack_webhook_conn_id=self.slack_conn_id, + proxy=self.slack_proxy, + timeout=self.slack_timeout, + retry_handlers=self.slack_retry_handlers, + ) def render_template_fields(self, context, jinja_env=None) -> None: # If this is the first render of the template fields, exclude slack_message from rendering since @@ -197,8 +213,10 @@ class SqlToSlackApiFileOperator(BaseSqlToSlackOperator): If omitting this parameter, then file will send to workspace. :param slack_initial_comment: The message text introducing the file in specified ``slack_channels``. :param slack_title: Title of file. + :param slack_base_url: A string representing the Slack API base URL. Optional :param df_kwargs: Keyword arguments forwarded to ``pandas.DataFrame.to_{format}()`` method. + Example: .. code-block:: python @@ -241,6 +259,7 @@ def __init__( slack_channels: str | Sequence[str] | None = None, slack_initial_comment: str | None = None, slack_title: str | None = None, + slack_base_url: str | None = None, df_kwargs: dict | None = None, **kwargs, ): @@ -252,6 +271,7 @@ def __init__( self.slack_channels = slack_channels self.slack_initial_comment = slack_initial_comment self.slack_title = slack_title + self.slack_base_url = slack_base_url self.df_kwargs = df_kwargs or {} def execute(self, context: Context) -> None: @@ -261,7 +281,13 @@ def execute(self, context: Context) -> None: supported_file_formats=self.SUPPORTED_FILE_FORMATS, ) - slack_hook = SlackHook(slack_conn_id=self.slack_conn_id) + slack_hook = SlackHook( + slack_conn_id=self.slack_conn_id, + base_url=self.slack_base_url, + timeout=self.slack_timeout, + proxy=self.slack_proxy, + retry_handlers=self.slack_retry_handlers, + ) with NamedTemporaryFile(mode="w+", suffix=f"_{self.slack_filename}") as fp: # tempfile.NamedTemporaryFile used only for create and remove temporary file, # pandas will open file in correct mode itself depend on file type. diff --git a/tests/providers/slack/notifications/test_slack.py b/tests/providers/slack/notifications/test_slack.py index 295a01d092b68..0e10bacd30907 100644 --- a/tests/providers/slack/notifications/test_slack.py +++ b/tests/providers/slack/notifications/test_slack.py @@ -19,17 +19,42 @@ from unittest import mock +import pytest + from airflow.operators.empty import EmptyOperator from airflow.providers.slack.notifications.slack import SlackNotifier, send_slack_notification +DEFAULT_HOOKS_PARAMETERS = {"base_url": None, "timeout": None, "proxy": None, "retry_handlers": None} + class TestSlackNotifier: @mock.patch("airflow.providers.slack.notifications.slack.SlackHook") - def test_slack_notifier(self, mock_slack_hook, dag_maker): + @pytest.mark.parametrize( + "extra_kwargs, hook_extra_kwargs", + [ + pytest.param({}, DEFAULT_HOOKS_PARAMETERS, id="default-hook-parameters"), + pytest.param( + { + "base_url": "https://foo.bar", + "timeout": 42, + "proxy": "http://spam.egg", + "retry_handlers": [], + }, + { + "base_url": "https://foo.bar", + "timeout": 42, + "proxy": "http://spam.egg", + "retry_handlers": [], + }, + id="with-extra-hook-parameters", + ), + ], + ) + def test_slack_notifier(self, mock_slack_hook, dag_maker, extra_kwargs, hook_extra_kwargs): with dag_maker("test_slack_notifier") as dag: EmptyOperator(task_id="task1") - notifier = send_slack_notification(text="test") + notifier = send_slack_notification(slack_conn_id="test_conn_id", text="test", **extra_kwargs) notifier({"dag": dag}) mock_slack_hook.return_value.call.assert_called_once_with( "chat.postMessage", @@ -43,6 +68,7 @@ def test_slack_notifier(self, mock_slack_hook, dag_maker): "blocks": "[]", }, ) + mock_slack_hook.assert_called_once_with(slack_conn_id="test_conn_id", **hook_extra_kwargs) @mock.patch("airflow.providers.slack.notifications.slack.SlackHook") def test_slack_notifier_with_notifier_class(self, mock_slack_hook, dag_maker): diff --git a/tests/providers/slack/notifications/test_slack_webhook.py b/tests/providers/slack/notifications/test_slack_webhook.py index e595766bfc19e..12f27c4b23b33 100644 --- a/tests/providers/slack/notifications/test_slack_webhook.py +++ b/tests/providers/slack/notifications/test_slack_webhook.py @@ -19,21 +19,42 @@ from unittest import mock +import pytest + from airflow.operators.empty import EmptyOperator from airflow.providers.slack.notifications.slack_webhook import ( SlackWebhookNotifier, send_slack_webhook_notification, ) +DEFAULT_HOOKS_PARAMETERS = {"timeout": None, "proxy": None, "retry_handlers": None} + class TestSlackNotifier: def test_class_and_notifier_are_same(self): assert send_slack_webhook_notification is SlackWebhookNotifier @mock.patch("airflow.providers.slack.notifications.slack_webhook.SlackWebhookHook") - def test_slack_webhook_notifier(self, mock_slack_hook): + @pytest.mark.parametrize( + "slack_op_kwargs, hook_extra_kwargs", + [ + pytest.param({}, DEFAULT_HOOKS_PARAMETERS, id="default-hook-parameters"), + pytest.param( + {"timeout": 42, "proxy": "http://spam.egg", "retry_handlers": []}, + {"timeout": 42, "proxy": "http://spam.egg", "retry_handlers": []}, + id="with-extra-hook-parameters", + ), + ], + ) + def test_slack_webhook_notifier(self, mock_slack_hook, slack_op_kwargs, hook_extra_kwargs): notifier = send_slack_webhook_notification( - text="foo-bar", blocks="spam-egg", attachments="baz-qux", unfurl_links=True, unfurl_media=False + slack_webhook_conn_id="test_conn_id", + text="foo-bar", + blocks="spam-egg", + attachments="baz-qux", + unfurl_links=True, + unfurl_media=False, + **slack_op_kwargs, ) notifier.notify({}) mock_slack_hook.return_value.send.assert_called_once_with( @@ -43,6 +64,7 @@ def test_slack_webhook_notifier(self, mock_slack_hook): unfurl_media=False, attachments="baz-qux", ) + mock_slack_hook.assert_called_once_with(slack_webhook_conn_id="test_conn_id", **hook_extra_kwargs) @mock.patch("airflow.providers.slack.notifications.slack_webhook.SlackWebhookHook") def test_slack_webhook_templated(self, mock_slack_hook, dag_maker): diff --git a/tests/providers/slack/operators/test_slack.py b/tests/providers/slack/operators/test_slack.py index 8dfebc9b2d292..52a28ef4d8e7a 100644 --- a/tests/providers/slack/operators/test_slack.py +++ b/tests/providers/slack/operators/test_slack.py @@ -31,17 +31,43 @@ ) SLACK_API_TEST_CONNECTION_ID = "test_slack_conn_id" +DEFAULT_HOOKS_PARAMETERS = {"base_url": None, "timeout": None, "proxy": None, "retry_handlers": None} class TestSlackAPIOperator: @mock.patch("airflow.providers.slack.operators.slack.SlackHook") - def test_hook(self, mock_slack_hook_cls): + @pytest.mark.parametrize( + "slack_op_kwargs, hook_extra_kwargs", + [ + pytest.param({}, DEFAULT_HOOKS_PARAMETERS, id="default-hook-parameters"), + pytest.param( + { + "base_url": "https://foo.bar", + "timeout": 42, + "proxy": "http://spam.egg", + "retry_handlers": [], + }, + { + "base_url": "https://foo.bar", + "timeout": 42, + "proxy": "http://spam.egg", + "retry_handlers": [], + }, + id="with-extra-hook-parameters", + ), + ], + ) + def test_hook(self, mock_slack_hook_cls, slack_op_kwargs, hook_extra_kwargs): mock_slack_hook = mock_slack_hook_cls.return_value - op = SlackAPIOperator(task_id="test-mask-token", slack_conn_id=SLACK_API_TEST_CONNECTION_ID) + op = SlackAPIOperator( + task_id="test-mask-token", slack_conn_id=SLACK_API_TEST_CONNECTION_ID, **slack_op_kwargs + ) hook = op.hook assert hook == mock_slack_hook assert hook is op.hook - mock_slack_hook_cls.assert_called_once_with(slack_conn_id=SLACK_API_TEST_CONNECTION_ID) + mock_slack_hook_cls.assert_called_once_with( + slack_conn_id=SLACK_API_TEST_CONNECTION_ID, **hook_extra_kwargs + ) class TestSlackAPIPostOperator: diff --git a/tests/providers/slack/operators/test_slack_webhook.py b/tests/providers/slack/operators/test_slack_webhook.py index abc8900c82d66..abb9bf70bc0d8 100644 --- a/tests/providers/slack/operators/test_slack_webhook.py +++ b/tests/providers/slack/operators/test_slack_webhook.py @@ -23,6 +23,8 @@ from airflow.providers.slack.operators.slack_webhook import SlackWebhookOperator +DEFAULT_HOOKS_PARAMETERS = {"timeout": None, "proxy": None, "retry_handlers": None} + class TestSlackWebhookOperator: def setup_method(self): @@ -34,14 +36,28 @@ def setup_method(self): "icon_url": None, } - @pytest.mark.parametrize("proxy", [None, "https://localhost:9999"]) @mock.patch("airflow.providers.slack.operators.slack_webhook.SlackWebhookHook") - def test_hook(self, mock_slackwebhook_cls, proxy): + @pytest.mark.parametrize( + "slack_op_kwargs, hook_extra_kwargs", + [ + pytest.param({}, DEFAULT_HOOKS_PARAMETERS, id="default-hook-parameters"), + pytest.param( + {"timeout": 42, "proxy": "http://spam.egg", "retry_handlers": []}, + {"timeout": 42, "proxy": "http://spam.egg", "retry_handlers": []}, + id="with-extra-hook-parameters", + ), + ], + ) + def test_hook(self, mock_slackwebhook_cls, slack_op_kwargs, hook_extra_kwargs): """Test get cached ``SlackWebhookHook`` hook.""" - op = SlackWebhookOperator(task_id="test_hook", slack_webhook_conn_id="test_conn_id", proxy=proxy) + op = SlackWebhookOperator( + task_id="test_hook", slack_webhook_conn_id="test_conn_id", **slack_op_kwargs + ) hook = op.hook assert hook is op.hook, "Expected cached hook" - mock_slackwebhook_cls.assert_called_once_with(slack_webhook_conn_id="test_conn_id", proxy=proxy) + mock_slackwebhook_cls.assert_called_once_with( + slack_webhook_conn_id="test_conn_id", **hook_extra_kwargs + ) def test_assert_templated_fields(self): """Test expected templated fields.""" diff --git a/tests/providers/slack/transfers/test_sql_to_slack.py b/tests/providers/slack/transfers/test_sql_to_slack.py index 07e6e87aba399..c2d698719f736 100644 --- a/tests/providers/slack/transfers/test_sql_to_slack.py +++ b/tests/providers/slack/transfers/test_sql_to_slack.py @@ -109,6 +109,7 @@ def test_get_query_results(self, mock_op_get_hook, sql, parameters): class TestSqlToSlackOperator: def setup_method(self): self.example_dag = DAG(TEST_DAG_ID, start_date=DEFAULT_DATE) + self.default_hook_parameters = {"timeout": None, "proxy": None, "retry_handlers": None} @staticmethod def _construct_operator(**kwargs): @@ -116,7 +117,20 @@ def _construct_operator(**kwargs): return operator @mock.patch("airflow.providers.slack.transfers.sql_to_slack.SlackWebhookHook") - def test_rendering_and_message_execution(self, mock_slack_hook_class): + @pytest.mark.parametrize( + "slack_op_kwargs, hook_extra_kwargs", + [ + pytest.param( + {}, {"timeout": None, "proxy": None, "retry_handlers": None}, id="default-hook-parameters" + ), + pytest.param( + {"slack_timeout": 42, "slack_proxy": "http://spam.egg", "slack_retry_handlers": []}, + {"timeout": 42, "proxy": "http://spam.egg", "retry_handlers": []}, + id="with-extra-hook-parameters", + ), + ], + ) + def test_rendering_and_message_execution(self, mock_slack_hook_class, slack_op_kwargs, hook_extra_kwargs): mock_dbapi_hook = mock.Mock() test_df = pd.DataFrame({"a": "1", "b": "2"}, index=[0, 1]) @@ -130,6 +144,7 @@ def test_rendering_and_message_execution(self, mock_slack_hook_class): "slack_channel": "#test", "sql": "sql {{ ds }}", "dag": self.example_dag, + **slack_op_kwargs, } sql_to_slack_operator = self._construct_operator(**operator_args) @@ -138,7 +153,9 @@ def test_rendering_and_message_execution(self, mock_slack_hook_class): sql_to_slack_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) # Test that the Slack hook is instantiated with the right parameters - mock_slack_hook_class.assert_called_once_with(slack_webhook_conn_id="slack_connection") + mock_slack_hook_class.assert_called_once_with( + slack_webhook_conn_id="slack_connection", **hook_extra_kwargs + ) # Test that the `SlackWebhookHook.send` method gets run once slack_webhook_hook.send.assert_called_once_with( @@ -169,7 +186,9 @@ def test_rendering_and_message_execution_with_slack_hook(self, mock_slack_hook_c sql_to_slack_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) # Test that the Slack hook is instantiated with the right parameters - mock_slack_hook_class.assert_called_once_with(slack_webhook_conn_id="slack_connection") + mock_slack_hook_class.assert_called_once_with( + slack_webhook_conn_id="slack_connection", **self.default_hook_parameters + ) # Test that the `SlackWebhookHook.send` method gets run once slack_webhook_hook.send.assert_called_once_with( @@ -210,7 +229,9 @@ def test_rendering_custom_df_name_message_execution(self, mock_slack_hook_class) sql_to_slack_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) # Test that the Slack hook is instantiated with the right parameters - mock_slack_hook_class.assert_called_once_with(slack_webhook_conn_id="slack_connection") + mock_slack_hook_class.assert_called_once_with( + slack_webhook_conn_id="slack_connection", **self.default_hook_parameters + ) # Test that the `SlackWebhookHook.send` method gets run once slack_webhook_hook.send.assert_called_once_with( @@ -308,6 +329,31 @@ def setup_method(self): @pytest.mark.parametrize("channels", ["#random", "#random,#general", None]) @pytest.mark.parametrize("initial_comment", [None, "Test Comment"]) @pytest.mark.parametrize("title", [None, "Test File Title"]) + @pytest.mark.parametrize( + "slack_op_kwargs, hook_extra_kwargs", + [ + pytest.param( + {}, + {"base_url": None, "timeout": None, "proxy": None, "retry_handlers": None}, + id="default-hook-parameters", + ), + pytest.param( + { + "slack_base_url": "https://foo.bar", + "slack_timeout": 42, + "slack_proxy": "http://spam.egg", + "slack_retry_handlers": [], + }, + { + "base_url": "https://foo.bar", + "timeout": 42, + "proxy": "http://spam.egg", + "retry_handlers": [], + }, + id="with-extra-hook-parameters", + ), + ], + ) def test_send_file( self, mock_slack_hook_cls, @@ -318,6 +364,8 @@ def test_send_file( channels, initial_comment, title, + slack_op_kwargs: dict, + hook_extra_kwargs: dict, ): # Mock Hook mock_send_file = mock.MagicMock() @@ -337,11 +385,14 @@ def test_send_file( "slack_initial_comment": initial_comment, "slack_title": title, "df_kwargs": df_kwargs, + **slack_op_kwargs, } op = SqlToSlackApiFileOperator(task_id="test_send_file", **op_kwargs) op.execute(mock.MagicMock()) - mock_slack_hook_cls.assert_called_once_with(slack_conn_id="expected-test-slack-conn-id") + mock_slack_hook_cls.assert_called_once_with( + slack_conn_id="expected-test-slack-conn-id", **hook_extra_kwargs + ) mock_get_query_results.assert_called_once_with() mock_df_output_method.assert_called_once_with(mock.ANY, **(df_kwargs or {})) mock_send_file.assert_called_once_with(