Skip to content

Commit

Permalink
Pass additional arguments from Slack's Operators/Notifiers to Hooks (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored Oct 19, 2023
1 parent 50f6483 commit e6f4451
Show file tree
Hide file tree
Showing 10 changed files with 261 additions and 28 deletions.
27 changes: 24 additions & 3 deletions airflow/providers/slack/notifications/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@

import json
from functools import cached_property
from typing import Sequence
from typing import TYPE_CHECKING, Sequence

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers.slack.hooks.slack import SlackHook

try:
from airflow.notifications.basenotifier import BaseNotifier
Expand All @@ -30,7 +31,8 @@
"Failed to import BaseNotifier. This feature is only available in Airflow versions >= 2.6.0"
)

from airflow.providers.slack.hooks.slack import SlackHook
if TYPE_CHECKING:
from slack_sdk.http_retry import RetryHandler

ICON_URL: str = "https://raw.githubusercontent.com/apache/airflow/2.5.0/airflow/www/static/pin_100.png"

Expand All @@ -46,6 +48,11 @@ class SlackNotifier(BaseNotifier):
:param icon_url: The icon to use for the message. Optional
:param attachments: A list of attachments to send with the message. Optional
:param blocks: A list of blocks to send with the message. Optional
:param timeout: The maximum number of seconds the client will wait to connect
and receive a response from Slack. Optional
:param base_url: A string representing the Slack API base URL. Optional
:param proxy: Proxy to make the Slack API call. Optional
:param retry_handlers: List of handlers to customize retry logic in ``slack_sdk.WebClient``. Optional
"""

template_fields = ("text", "channel", "username", "attachments", "blocks")
Expand All @@ -60,6 +67,10 @@ def __init__(
icon_url: str = ICON_URL,
attachments: Sequence = (),
blocks: Sequence = (),
base_url: str | None = None,
proxy: str | None = None,
timeout: int | None = None,
retry_handlers: list[RetryHandler] | None = None,
):
super().__init__()
self.slack_conn_id = slack_conn_id
Expand All @@ -69,11 +80,21 @@ def __init__(
self.icon_url = icon_url
self.attachments = attachments
self.blocks = blocks
self.base_url = base_url
self.timeout = timeout
self.proxy = proxy
self.retry_handlers = retry_handlers

@cached_property
def hook(self) -> SlackHook:
"""Slack Hook."""
return SlackHook(slack_conn_id=self.slack_conn_id)
return SlackHook(
slack_conn_id=self.slack_conn_id,
base_url=self.base_url,
timeout=self.timeout,
proxy=self.proxy,
retry_handlers=self.retry_handlers,
)

def notify(self, context):
"""Send a message to a Slack Channel."""
Expand Down
14 changes: 12 additions & 2 deletions airflow/providers/slack/notifications/slack_webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook
Expand All @@ -29,6 +30,9 @@
"Failed to import BaseNotifier. This feature is only available in Airflow versions >= 2.6.0"
)

if TYPE_CHECKING:
from slack_sdk.http_retry import RetryHandler


class SlackWebhookNotifier(BaseNotifier):
"""
Expand All @@ -45,9 +49,10 @@ class SlackWebhookNotifier(BaseNotifier):
:param unfurl_links: Option to indicate whether text url should unfurl. Optional
:param unfurl_media: Option to indicate whether media url should unfurl. Optional
:param timeout: The maximum number of seconds the client will wait to connect. Optional
and receive a response from Slack. If not set than default WebhookClient value will use.
and receive a response from Slack. Optional
:param proxy: Proxy to make the Slack Incoming Webhook call. Optional
:param attachments: A list of attachments to send with the message. Optional
:param retry_handlers: List of handlers to customize retry logic in ``slack_sdk.WebhookClient``. Optional
"""

template_fields = ("slack_webhook_conn_id", "text", "attachments", "blocks", "proxy", "timeout")
Expand All @@ -63,6 +68,7 @@ def __init__(
proxy: str | None = None,
timeout: int | None = None,
attachments: list | None = None,
retry_handlers: list[RetryHandler] | None = None,
):
super().__init__()
self.slack_webhook_conn_id = slack_webhook_conn_id
Expand All @@ -73,12 +79,16 @@ def __init__(
self.unfurl_media = unfurl_media
self.timeout = timeout
self.proxy = proxy
self.retry_handlers = retry_handlers

@cached_property
def hook(self) -> SlackWebhookHook:
"""Slack Incoming Webhook Hook."""
return SlackWebhookHook(
slack_webhook_conn_id=self.slack_webhook_conn_id, proxy=self.proxy, timeout=self.timeout
slack_webhook_conn_id=self.slack_webhook_conn_id,
proxy=self.proxy,
timeout=self.timeout,
retry_handlers=self.retry_handlers,
)

def notify(self, context):
Expand Down
27 changes: 24 additions & 3 deletions airflow/providers/slack/operators/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
import json
import warnings
from functools import cached_property
from typing import Any, Sequence
from typing import TYPE_CHECKING, Any, Sequence

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.slack.hooks.slack import SlackHook

if TYPE_CHECKING:
from slack_sdk.http_retry import RetryHandler


class SlackAPIOperator(BaseOperator):
"""Base Slack Operator class.
Expand All @@ -34,7 +37,11 @@ class SlackAPIOperator(BaseOperator):
which its password is Slack API token.
:param method: The Slack API Method to Call (https://api.slack.com/methods). Optional
:param api_params: API Method call parameters (https://api.slack.com/methods). Optional
:param client_args: Slack Hook parameters. Optional. Check airflow.providers.slack.hooks.SlackHook
:param timeout: The maximum number of seconds the client will wait to connect
and receive a response from Slack. Optional
:param base_url: A string representing the Slack API base URL. Optional
:param proxy: Proxy to make the Slack API call. Optional
:param retry_handlers: List of handlers to customize retry logic in ``slack_sdk.WebClient``. Optional
"""

def __init__(
Expand All @@ -43,17 +50,31 @@ def __init__(
slack_conn_id: str = SlackHook.default_conn_name,
method: str | None = None,
api_params: dict | None = None,
base_url: str | None = None,
proxy: str | None = None,
timeout: int | None = None,
retry_handlers: list[RetryHandler] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.slack_conn_id = slack_conn_id
self.method = method
self.api_params = api_params
self.base_url = base_url
self.timeout = timeout
self.proxy = proxy
self.retry_handlers = retry_handlers

@cached_property
def hook(self) -> SlackHook:
"""Slack Hook."""
return SlackHook(slack_conn_id=self.slack_conn_id)
return SlackHook(
slack_conn_id=self.slack_conn_id,
base_url=self.base_url,
timeout=self.timeout,
proxy=self.proxy,
retry_handlers=self.retry_handlers,
)

def construct_api_call_params(self) -> Any:
"""API call parameters used by the execute function.
Expand Down
18 changes: 16 additions & 2 deletions airflow/providers/slack/operators/slack_webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook

if TYPE_CHECKING:
from slack_sdk.http_retry import RetryHandler

from airflow.utils.context import Context


Expand Down Expand Up @@ -51,7 +53,10 @@ class SlackWebhookOperator(BaseOperator):
:param username: The username to post to slack with
:param icon_emoji: The emoji to use as icon for the user posting to Slack
:param icon_url: The icon image URL string to use in place of the default icon.
:param proxy: Proxy to use to make the Slack webhook call
:param proxy: Proxy to make the Slack Incoming Webhook call. Optional
:param timeout: The maximum number of seconds the client will wait to connect
and receive a response from Slack. Optional
:param retry_handlers: List of handlers to customize retry logic in ``slack_sdk.WebhookClient``. Optional
"""

template_fields: Sequence[str] = (
Expand All @@ -75,6 +80,8 @@ def __init__(
icon_emoji: str | None = None,
icon_url: str | None = None,
proxy: str | None = None,
timeout: int | None = None,
retry_handlers: list[RetryHandler] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -87,11 +94,18 @@ def __init__(
self.username = username
self.icon_emoji = icon_emoji
self.icon_url = icon_url
self.timeout = timeout
self.retry_handlers = retry_handlers

@cached_property
def hook(self) -> SlackWebhookHook:
"""Create and return an SlackWebhookHook (cached)."""
return SlackWebhookHook(slack_webhook_conn_id=self.slack_webhook_conn_id, proxy=self.proxy)
return SlackWebhookHook(
slack_webhook_conn_id=self.slack_webhook_conn_id,
proxy=self.proxy,
timeout=self.timeout,
retry_handlers=self.retry_handlers,
)

def execute(self, context: Context) -> None:
"""Call the SlackWebhookHook to post the provided Slack message."""
Expand Down
30 changes: 28 additions & 2 deletions airflow/providers/slack/transfers/sql_to_slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

if TYPE_CHECKING:
import pandas as pd
from slack_sdk.http_retry import RetryHandler

from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.utils.context import Context
Expand All @@ -44,6 +45,10 @@ class BaseSqlToSlackOperator(BaseOperator):
:param sql_hook_params: Extra config params to be passed to the underlying hook.
Should match the desired hook constructor params.
:param parameters: The parameters to pass to the SQL query.
:param slack_proxy: Proxy to make the Slack Incoming Webhook / API calls. Optional
:param slack_timeout: The maximum number of seconds the client will wait to connect
and receive a response from Slack. Optional
:param slack_retry_handlers: List of handlers to customize retry logic. Optional
"""

def __init__(
Expand All @@ -53,13 +58,19 @@ def __init__(
sql_conn_id: str,
sql_hook_params: dict | None = None,
parameters: Iterable | Mapping[str, Any] | None = None,
slack_proxy: str | None = None,
slack_timeout: int | None = None,
slack_retry_handlers: list[RetryHandler] | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.sql_conn_id = sql_conn_id
self.sql_hook_params = sql_hook_params
self.sql = sql
self.parameters = parameters
self.slack_proxy = slack_proxy
self.slack_timeout = slack_timeout
self.slack_retry_handlers = slack_retry_handlers

def _get_hook(self) -> DbApiHook:
self.log.debug("Get connection for %s", self.sql_conn_id)
Expand Down Expand Up @@ -146,7 +157,12 @@ def _render_and_send_slack_message(self, context, df) -> None:
slack_hook.send(text=self.slack_message, channel=self.slack_channel)

def _get_slack_hook(self) -> SlackWebhookHook:
return SlackWebhookHook(slack_webhook_conn_id=self.slack_conn_id)
return SlackWebhookHook(
slack_webhook_conn_id=self.slack_conn_id,
proxy=self.slack_proxy,
timeout=self.slack_timeout,
retry_handlers=self.slack_retry_handlers,
)

def render_template_fields(self, context, jinja_env=None) -> None:
# If this is the first render of the template fields, exclude slack_message from rendering since
Expand Down Expand Up @@ -197,8 +213,10 @@ class SqlToSlackApiFileOperator(BaseSqlToSlackOperator):
If omitting this parameter, then file will send to workspace.
:param slack_initial_comment: The message text introducing the file in specified ``slack_channels``.
:param slack_title: Title of file.
:param slack_base_url: A string representing the Slack API base URL. Optional
:param df_kwargs: Keyword arguments forwarded to ``pandas.DataFrame.to_{format}()`` method.
Example:
.. code-block:: python
Expand Down Expand Up @@ -241,6 +259,7 @@ def __init__(
slack_channels: str | Sequence[str] | None = None,
slack_initial_comment: str | None = None,
slack_title: str | None = None,
slack_base_url: str | None = None,
df_kwargs: dict | None = None,
**kwargs,
):
Expand All @@ -252,6 +271,7 @@ def __init__(
self.slack_channels = slack_channels
self.slack_initial_comment = slack_initial_comment
self.slack_title = slack_title
self.slack_base_url = slack_base_url
self.df_kwargs = df_kwargs or {}

def execute(self, context: Context) -> None:
Expand All @@ -261,7 +281,13 @@ def execute(self, context: Context) -> None:
supported_file_formats=self.SUPPORTED_FILE_FORMATS,
)

slack_hook = SlackHook(slack_conn_id=self.slack_conn_id)
slack_hook = SlackHook(
slack_conn_id=self.slack_conn_id,
base_url=self.slack_base_url,
timeout=self.slack_timeout,
proxy=self.slack_proxy,
retry_handlers=self.slack_retry_handlers,
)
with NamedTemporaryFile(mode="w+", suffix=f"_{self.slack_filename}") as fp:
# tempfile.NamedTemporaryFile used only for create and remove temporary file,
# pandas will open file in correct mode itself depend on file type.
Expand Down
30 changes: 28 additions & 2 deletions tests/providers/slack/notifications/test_slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,42 @@

from unittest import mock

import pytest

from airflow.operators.empty import EmptyOperator
from airflow.providers.slack.notifications.slack import SlackNotifier, send_slack_notification

DEFAULT_HOOKS_PARAMETERS = {"base_url": None, "timeout": None, "proxy": None, "retry_handlers": None}


class TestSlackNotifier:
@mock.patch("airflow.providers.slack.notifications.slack.SlackHook")
def test_slack_notifier(self, mock_slack_hook, dag_maker):
@pytest.mark.parametrize(
"extra_kwargs, hook_extra_kwargs",
[
pytest.param({}, DEFAULT_HOOKS_PARAMETERS, id="default-hook-parameters"),
pytest.param(
{
"base_url": "https://foo.bar",
"timeout": 42,
"proxy": "http://spam.egg",
"retry_handlers": [],
},
{
"base_url": "https://foo.bar",
"timeout": 42,
"proxy": "http://spam.egg",
"retry_handlers": [],
},
id="with-extra-hook-parameters",
),
],
)
def test_slack_notifier(self, mock_slack_hook, dag_maker, extra_kwargs, hook_extra_kwargs):
with dag_maker("test_slack_notifier") as dag:
EmptyOperator(task_id="task1")

notifier = send_slack_notification(text="test")
notifier = send_slack_notification(slack_conn_id="test_conn_id", text="test", **extra_kwargs)
notifier({"dag": dag})
mock_slack_hook.return_value.call.assert_called_once_with(
"chat.postMessage",
Expand All @@ -43,6 +68,7 @@ def test_slack_notifier(self, mock_slack_hook, dag_maker):
"blocks": "[]",
},
)
mock_slack_hook.assert_called_once_with(slack_conn_id="test_conn_id", **hook_extra_kwargs)

@mock.patch("airflow.providers.slack.notifications.slack.SlackHook")
def test_slack_notifier_with_notifier_class(self, mock_slack_hook, dag_maker):
Expand Down
Loading

0 comments on commit e6f4451

Please sign in to comment.