Skip to content

Commit

Permalink
Add slack message on TSV load complete (#369)
Browse files Browse the repository at this point in the history
* Create a stubbed task to report load completion

* Add logging for pull_data task duration

* Add logging for record_count of rows upserted

* Update task dependencies

* Remove unnecessary xcom push and use default

* Update message, remove unneeded logger config

* Use existing provider_name variable

* Update push_output_paths_wrapper test with expected number of xcoms

* Send slack message

* Fix broken formatting

* Clean up tests

* Italicize note in the slack message

* Change name of slack override variable to be more descriptive

* Put all slack messages behind the environment check

* Update send message mock in tests

* Better default when duration is none
  • Loading branch information
stacimc authored Feb 23, 2022
1 parent 9538f38 commit 5188b38
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 38 deletions.
30 changes: 27 additions & 3 deletions openverse_catalog/dags/common/dag_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import inspect
import logging
import os
import time
from datetime import datetime, timedelta
from typing import Callable, Dict, List, Optional, Sequence

Expand All @@ -65,7 +66,7 @@
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule
from common import slack
from common.loader import loader, s3, sql
from common.loader import loader, reporting, s3, sql


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -126,9 +127,17 @@ def _push_output_paths_wrapper(
ti.xcom_push(key=f"{media_type}_tsv", value=store.output_path)

logger.info("Running provider function")

start_time = time.perf_counter()
# Not passing kwargs here because Airflow throws a bunch of stuff in there that none
# of our provider scripts are expecting.
return func(*args)
data = func(*args)
end_time = time.perf_counter()

duration = end_time - start_time
ti.xcom_push(key="duration", value=duration)

return data


def create_provider_api_workflow(
Expand Down Expand Up @@ -255,6 +264,20 @@ def create_provider_api_workflow(
"identifier": identifier,
},
)
report_load_completion = PythonOperator(
task_id="report_load_completion",
python_callable=reporting.report_completion,
op_kwargs={
"provider_name": provider_name,
"media_type": media_type,
"duration": XCOM_PULL_TEMPLATE.format(
pull_data.task_id, "duration"
),
"record_count": XCOM_PULL_TEMPLATE.format(
load_from_s3.task_id, "return_value"
),
},
)
drop_loading_table = PythonOperator(
task_id="drop_loading_table",
python_callable=sql.drop_load_table,
Expand All @@ -265,7 +288,8 @@ def create_provider_api_workflow(
},
trigger_rule=TriggerRule.ALL_DONE,
)
[create_loading_table, copy_to_s3] >> load_from_s3 >> drop_loading_table
[create_loading_table, copy_to_s3] >> load_from_s3
load_from_s3 >> [report_load_completion, drop_loading_table]

pull_data >> load_data

Expand Down
3 changes: 2 additions & 1 deletion openverse_catalog/dags/common/loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def load_from_s3(
sql.load_s3_data_to_intermediate_table(
postgres_conn_id, bucket, key, identifier, media_type
)
sql.upsert_records_to_db_table(
# Returns record count
return sql.upsert_records_to_db_table(
postgres_conn_id, identifier, media_type=media_type, tsv_version=tsv_version
)
29 changes: 29 additions & 0 deletions openverse_catalog/dags/common/loader/reporting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import logging

from common.slack import send_message


logger = logging.getLogger(__name__)


def report_completion(provider_name, media_type, duration, record_count):
"""
Send a Slack notification when the load_data task has completed.
Messages are only sent out in production and if a Slack connection is defined.
In all cases the data is logged.
"""

# This happens when the task is manually set to `success` in Airflow before
# completing.
duration = "_No data_" if duration == "None" else duration

message = f"""
*Provider*: `{provider_name}`
*Media Type*: `{media_type}`
*Number of Records Upserted*: {record_count}
*Duration of data pull task*: {duration}
* _Duration includes time taken to pull data of all media types._
"""
send_message(message, username="Airflow DAG Load Data Complete")
logger.info(message)
2 changes: 1 addition & 1 deletion openverse_catalog/dags/common/loader/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def upsert_records_to_db_table(
{upsert_conflict_string}
"""
)
postgres.run(upsert_query)
return postgres.run(upsert_query, handler=lambda c: c.rowcount)


def drop_load_table(
Expand Down
36 changes: 26 additions & 10 deletions openverse_catalog/dags/common/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
This class is intended to be used with a channel-specific slack webhook.
More information can be found here: https://app.slack.com/block-kit-builder.
## Messages are not configured to send in development
Messages or alerts sent using `send_message` or `on_failure_callback` will only
send if a Slack connection is defined and we are running in production. You can
manually override this for testing purposes by setting the `slack_message_override`
variable to `true` in the Airflow UI.
## Send multiple messages - payload is reset after sending
>>> slack = SlackMessage(username="Multi-message Test")
Expand Down Expand Up @@ -215,31 +222,38 @@ def send_message(
http_conn_id: str = SLACK_NOTIFICATIONS_CONN_ID,
) -> None:
"""Send a simple slack message, convenience message for short/simple messages."""
if not should_send_message(http_conn_id):
return
s = SlackMessage(username, icon_emoji, http_conn_id=http_conn_id)
s.add_text(text, plain_text=not markdown)
s.send(text)


def on_failure_callback(context: dict) -> None:
def should_send_message(http_conn_id=SLACK_NOTIFICATIONS_CONN_ID):
"""
Send an alert out regarding a failure to Slack.
Errors are only sent out in production and if a Slack connection is defined.
Returns true if a Slack connection is defined and we are in production (or
the message override is set).
"""
# Exit early if no slack connection exists
hook = HttpHook(http_conn_id=SLACK_ALERTS_CONN_ID)
hook = HttpHook(http_conn_id=http_conn_id)
try:
hook.get_conn()
except AirflowNotFoundException:
return
return False

# Exit early if we aren't on production or if force alert is not set
environment = Variable.get("environment", default_var="dev")
force_alert = Variable.get(
"force_slack_alert", default_var=False, deserialize_json=True
force_message = Variable.get(
"slack_message_override", default_var=False, deserialize_json=True
)
if not (environment == "prod" or force_alert):
return
return environment == "prod" or force_message


def on_failure_callback(context: dict) -> None:
"""
Send an alert out regarding a failure to Slack.
Errors are only sent out in production and if a Slack connection is defined.
"""
# Get relevant info
ti = context["task_instance"]
execution_date = context["execution_date"]
Expand All @@ -263,4 +277,6 @@ def on_failure_callback(context: dict) -> None:
*Log*: {ti.log_url}
{exception_message}
"""
send_message(message, username="Airflow DAG Failure")
send_message(
message, username="Airflow DAG Failure", http_conn_id=SLACK_ALERTS_CONN_ID
)
23 changes: 23 additions & 0 deletions tests/dags/common/loader/test_reporting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from unittest import mock

import pytest
from common.loader.reporting import report_completion


@pytest.fixture(autouse=True)
def send_message_mock() -> mock.MagicMock:
with mock.patch("common.slack.SlackMessage.send") as SendMessageMock:
yield SendMessageMock


@pytest.mark.parametrize(
"should_send_message",
[True, False],
)
def test_report_completion(should_send_message):
with mock.patch(
"common.slack.should_send_message", return_value=should_send_message
):
report_completion("Jamendo", "Audio", None, 100)
# Send message is only called if `should_send_message` is True.
send_message_mock.called = should_send_message
15 changes: 11 additions & 4 deletions tests/dags/common/test_dag_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,18 @@ def test_push_output_paths_wrapper(func, media_types, stores):
ti_mock,
args=[func_mock, value],
)
assert ti_mock.xcom_push.call_count == len(
media_types
), "# of output paths didn't match # of stores expected"
for args, store in zip(ti_mock.xcom_push.calls, stores):
# There should be one call to xcom_push for each provided store, plus
# one final call to report the task duration.
expected_xcoms = len(media_types) + 1
actual_xcoms = ti_mock.xcom_push.call_count
assert (
actual_xcoms == expected_xcoms
), f"Expected {expected_xcoms} XComs but {actual_xcoms} pushed"
for args, store in zip(ti_mock.xcom_push.mock_calls[:-1], stores):
assert args.kwargs["value"] == store.output_path
# Check that the duration was reported
assert ti_mock.xcom_push.mock_calls[-1].kwargs["key"] == "duration"

# Check that the function itself was called with the provided args
func_mock.assert_called_once_with(value)

Expand Down
65 changes: 46 additions & 19 deletions tests/dags/common/test_slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

import pytest
from airflow.exceptions import AirflowNotFoundException
from common.slack import SlackMessage, on_failure_callback, send_message
from common.slack import (
SlackMessage,
on_failure_callback,
send_message,
should_send_message,
)


_FAKE_IMAGE = "http://image.com/img.jpg"
Expand Down Expand Up @@ -279,20 +284,48 @@ def test_send_fails(http_hook_mock):
s.send()


@pytest.mark.parametrize(
"environment, slack_message_override, expected_result",
[
("dev", False, False),
("dev", True, True),
("prod", False, True),
("prod", True, True),
],
)
def test_should_send_message(environment, slack_message_override, expected_result):
with mock.patch("common.slack.Variable") as MockVariable:
# Mock the calls to Variable.get, in order
MockVariable.get.side_effect = [environment, slack_message_override]
assert should_send_message() == expected_result


def test_should_send_message_is_false_without_hook(http_hook_mock):
http_hook_mock.get_conn.side_effect = AirflowNotFoundException("nope")
assert not should_send_message()


def test_send_message(http_hook_mock):
send_message("Sample text", username="DifferentUser")
http_hook_mock.run.assert_called_with(
endpoint=None,
data='{"username": "DifferentUser", "unfurl_links": true, "unfurl_media": true,'
' "icon_emoji": ":airflow:", "blocks": [{"type": "section", "text": '
'{"type": "mrkdwn", "text": "Sample text"}}], "text": "Sample text"}',
headers={"Content-type": "application/json"},
extra_options={"verify": True},
)
with mock.patch("common.slack.should_send_message", return_value=True):
send_message("Sample text", username="DifferentUser")
http_hook_mock.run.assert_called_with(
endpoint=None,
data='{"username": "DifferentUser", "unfurl_links": true, "unfurl_media": true,'
' "icon_emoji": ":airflow:", "blocks": [{"type": "section", "text": '
'{"type": "mrkdwn", "text": "Sample text"}}], "text": "Sample text"}',
headers={"Content-type": "application/json"},
extra_options={"verify": True},
)


def test_send_message_does_not_send_if_checks_fail(http_hook_mock):
with mock.patch("common.slack.should_send_message", return_value=False):
send_message("Sample text", username="DifferentUser")
http_hook_mock.run.assert_not_called()


@pytest.mark.parametrize(
"exception, environment, force_slack_alert, call_expected",
"exception, environment, slack_message_override, call_expected",
[
# Message with exception
(ValueError("Whoops!"), "dev", False, False),
Expand All @@ -312,7 +345,7 @@ def test_send_message(http_hook_mock):
],
)
def test_on_failure_callback(
exception, environment, force_slack_alert, call_expected, http_hook_mock
exception, environment, slack_message_override, call_expected, http_hook_mock
):
context = {
"task_instance": mock.Mock(),
Expand All @@ -322,17 +355,11 @@ def test_on_failure_callback(
with mock.patch("common.slack.Variable") as MockVariable:
run_mock = http_hook_mock.run
# Mock the calls to Variable.get, in order
MockVariable.get.side_effect = [environment, force_slack_alert]
MockVariable.get.side_effect = [environment, slack_message_override]
on_failure_callback(context)
assert run_mock.called == call_expected
if call_expected:
# Check that an exception message is present only if one is provided
assert bool(exception) ^ (
"Exception" not in run_mock.call_args.kwargs["data"]
)


def test_on_failure_callback_does_nothing_without_hook(http_hook_mock):
http_hook_mock.get_conn.side_effect = AirflowNotFoundException("nope")
on_failure_callback({})
http_hook_mock.run.assert_not_called()

0 comments on commit 5188b38

Please sign in to comment.