Add check to skip toggling CloudWatch alarms (#3682)
Co-authored-by: Staci Mullins <[email protected]>
Co-authored-by: Madison Swain-Bowden <[email protected]>
3 people authored Feb 14, 2024
1 parent b4b0cc9 commit 002066f
Showing 5 changed files with 117 additions and 10 deletions.
72 changes: 72 additions & 0 deletions catalog/dags/common/cloudwatch.py
@@ -0,0 +1,72 @@
"""
CloudWatchWrapper extracted partially from
https://github.com/awsdocs/aws-doc-sdk-examples/blob/54c3b82d8f9a12a862f9fcec44909829bda849af/python/example_code/cloudwatch/cloudwatch_basics.py
The CloudWatchWrapper requires the AWS_CLOUDWATCH_CONN_ID, or the `aws_default`
connection, to be set in the Airflow Connections.
Modifying alarms can be skipped by setting the `TOGGLE_CLOUDWATCH_ALARMS` Airflow
Variable to `False`; this is the desired behavior when running the Data Refresh
DAGs locally or in a development environment.
"""
import logging

from airflow.exceptions import AirflowSkipException
from airflow.models import Variable
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from botocore.exceptions import ClientError

from common.constants import AWS_CLOUDWATCH_CONN_ID


logger = logging.getLogger(__name__)


class CloudWatchWrapper:
    """Encapsulates Amazon CloudWatch functions."""

    def __init__(self, cloudwatch_resource):
        """:param cloudwatch_resource: A Boto3 CloudWatch resource."""
        self.cloudwatch_resource = cloudwatch_resource

    def enable_alarm_actions(self, alarm_name, enable):
        """
        Enable or disable actions on the specified alarm. Alarm actions can be
        used to send notifications or automate responses when an alarm enters a
        particular state.

        :param alarm_name: The name of the alarm.
        :param enable: When True, actions are enabled for the alarm. Otherwise,
            they are disabled.
        """
        try:
            alarm = self.cloudwatch_resource.Alarm(alarm_name)
            if enable:
                alarm.enable_actions()
            else:
                alarm.disable_actions()
            logger.info(
                "%s actions for alarm %s.",
                "Enabled" if enable else "Disabled",
                alarm_name,
            )
        except ClientError:
            logger.exception(
                "Couldn't %s actions for alarm %s.",
                "enable" if enable else "disable",
                alarm_name,
            )
            raise


def enable_or_disable_alarms(enable):
    toggle = Variable.get("TOGGLE_CLOUDWATCH_ALARMS", True, deserialize_json=True)
    if not toggle:
        raise AirflowSkipException("TOGGLE_CLOUDWATCH_ALARMS is set to False.")

    cloudwatch = AwsBaseHook(
        aws_conn_id=AWS_CLOUDWATCH_CONN_ID,
        resource_type="cloudwatch",
    )
    cw_wrapper = CloudWatchWrapper(cloudwatch.get_conn())
    cw_wrapper.enable_alarm_actions("ES Production CPU utilization above 50%", enable)
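For illustration, a minimal sketch of exercising the wrapper directly with boto3, outside of Airflow. It assumes AWS credentials and a region are available in the environment and that the module is importable as `common.cloudwatch`; the try/finally pattern is illustrative only and not part of this commit.

    import boto3

    from common.cloudwatch import CloudWatchWrapper

    # Assumes credentials/region come from the environment (e.g. AWS_DEFAULT_REGION).
    cloudwatch = boto3.resource("cloudwatch")
    wrapper = CloudWatchWrapper(cloudwatch)

    alarm_name = "ES Production CPU utilization above 50%"
    wrapper.enable_alarm_actions(alarm_name, False)  # silence alarm actions
    try:
        ...  # perform the disruptive work, e.g. a data refresh
    finally:
        wrapper.enable_alarm_actions(alarm_name, True)  # restore alarm actions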
1 change: 1 addition & 0 deletions catalog/dags/common/constants.py
@@ -40,6 +40,7 @@
OPENLEDGER_API_CONN_ID = "postgres_openledger_api"
POSTGRES_API_STAGING_CONN_ID = "postgres_openledger_api_staging"
AWS_CONN_ID = "aws_default"
AWS_CLOUDWATCH_CONN_ID = os.environ.get("AWS_CLOUDWATCH_CONN_ID", AWS_CONN_ID)
AWS_RDS_CONN_ID = os.environ.get("AWS_RDS_CONN_ID", AWS_CONN_ID)
ES_PROD_HTTP_CONN_ID = "elasticsearch_http_production"
REFRESH_POKE_INTERVAL = int(os.getenv("DATA_REFRESH_POKE_INTERVAL", 60 * 30))
12 changes: 10 additions & 2 deletions catalog/dags/common/ingestion_server.py
@@ -8,6 +8,7 @@
from airflow.exceptions import AirflowSkipException
from airflow.providers.http.operators.http import HttpOperator
from airflow.providers.http.sensors.http import HttpSensor
from airflow.utils.trigger_rule import TriggerRule
from requests import Response

from common.constants import (
@@ -111,6 +112,7 @@ def trigger_task(
    model: str,
    data: dict | None = None,
    http_conn_id: str = "data_refresh",
    trigger_rule: TriggerRule = TriggerRule.ALL_SUCCESS,
) -> HttpOperator:
    data = {
        **(data or {}),
@@ -124,6 +126,7 @@ def trigger_task(
        data=data,
        response_check=lambda response: response.status_code == 202,
        response_filter=response_filter_status_check_endpoint,
        trigger_rule=trigger_rule,
    )


@@ -133,6 +136,7 @@ def wait_for_task(
    timeout: timedelta,
    poke_interval: int = REFRESH_POKE_INTERVAL,
    http_conn_id: str = "data_refresh",
    trigger_rule: TriggerRule = TriggerRule.ALL_SUCCESS,
) -> HttpSensor:
    return HttpSensor(
        task_id=f"wait_for_{action.lower()}",
@@ -143,6 +147,7 @@ def wait_for_task(
        mode="reschedule",
        poke_interval=poke_interval,
        timeout=timeout.total_seconds(),
        trigger_rule=trigger_rule,
    )


@@ -153,9 +158,12 @@ def trigger_and_wait_for_task(
    data: dict | None = None,
    poke_interval: int = REFRESH_POKE_INTERVAL,
    http_conn_id: str = "data_refresh",
    trigger_rule: TriggerRule = TriggerRule.ALL_SUCCESS,
) -> tuple[HttpOperator, HttpSensor]:
    trigger = trigger_task(action, model, data, http_conn_id)
    waiter = wait_for_task(action, trigger, timeout, poke_interval, http_conn_id)
    trigger = trigger_task(action, model, data, http_conn_id, trigger_rule)
    waiter = wait_for_task(
        action, trigger, timeout, poke_interval, http_conn_id, trigger_rule
    )
    trigger >> waiter
    return trigger, waiter

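With the new `trigger_rule` parameter threaded through these helpers, callers can allow the trigger and wait tasks to run even when an upstream task was skipped, as the data refresh factory below does around the alarm-toggling tasks. A hedged sketch of such a call; the model and index suffix values are illustrative only:

    from datetime import timedelta

    from airflow.utils.trigger_rule import TriggerRule

    from common import ingestion_server

    # Run even if an optional upstream task (e.g. the alarm toggle) was skipped,
    # as long as nothing upstream failed.
    trigger, waiter = ingestion_server.trigger_and_wait_for_task(
        action="ingest_upstream",
        model="image",                      # illustrative
        data={"index_suffix": "20240214"},  # illustrative
        timeout=timedelta(days=1),
        trigger_rule=TriggerRule.NONE_FAILED,
    )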
34 changes: 26 additions & 8 deletions catalog/dags/data_refresh/data_refresh_task_factory.py
@@ -33,7 +33,7 @@
successful response will include the `status_check` url used to check on the
status of the refresh, which is passed on to the next task via XCom.
3. Finally the `wait_for_data_refresh` task waits for the data refresh to be
3. Finally, the `wait_for_data_refresh` task waits for the data refresh to be
complete by polling the `status_url`. Note this task does not need to be
able to suspend itself and free the worker slot, because we want to lock the
entire pool on waiting for a particular data refresh to run.
@@ -49,10 +49,11 @@
from collections.abc import Sequence

from airflow.models.baseoperator import chain
from airflow.operators.python import PythonOperator
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule

from common import ingestion_server
from common import cloudwatch, ingestion_server
from common.constants import PRODUCTION, XCOM_PULL_TEMPLATE
from common.sensors.single_run_external_dags_sensor import SingleRunExternalDAGsSensor
from common.sensors.utils import wait_for_external_dags
@@ -137,7 +138,14 @@ def create_data_refresh_task_group(
    generate_index_suffix = ingestion_server.generate_index_suffix.override(
        trigger_rule=TriggerRule.NONE_FAILED,
    )()
    tasks.append(generate_index_suffix)
    disable_alarms = PythonOperator(
        task_id="disable_sensitive_cloudwatch_alarms",
        python_callable=cloudwatch.enable_or_disable_alarms,
        op_kwargs={
            "enable": False,
        },
    )
    tasks.append([generate_index_suffix, disable_alarms])

    # Trigger the 'ingest_upstream' task on the ingestion server and await its
    # completion. This task copies the media table for the given model from the
@@ -147,10 +155,9 @@
        ingestion_server.trigger_and_wait_for_task(
            action="ingest_upstream",
            model=data_refresh.media_type,
            data={
                "index_suffix": generate_index_suffix,
            },
            data={"index_suffix": generate_index_suffix},
            timeout=data_refresh.data_refresh_timeout,
            trigger_rule=TriggerRule.NONE_FAILED,
        )
    tasks.append(ingest_upstream_tasks)

@@ -181,6 +188,15 @@ def create_data_refresh_task_group(
    # running against an index that is already promoted in production.
    tasks.append(create_filtered_index)

    enable_alarms = PythonOperator(
        task_id="enable_sensitive_cloudwatch_alarms",
        python_callable=cloudwatch.enable_or_disable_alarms,
        op_kwargs={
            "enable": True,
        },
        trigger_rule=TriggerRule.ALL_DONE,
    )

    # Trigger the `promote` task on the ingestion server and await its completion.
    # This task promotes the newly created API DB table and elasticsearch index. It
    # does not include promotion of the filtered index, which must be promoted
@@ -195,7 +211,7 @@
            },
            timeout=data_refresh.data_refresh_timeout,
        )
    tasks.append(promote_tasks)
    tasks.append([enable_alarms, promote_tasks])

    # Delete the alias' previous target index, now unused.
    delete_old_index = ingestion_server.trigger_task(
@@ -206,6 +222,7 @@
                get_current_index.task_id, "return_value"
            ),
        },
        trigger_rule=TriggerRule.NONE_FAILED,
    )
    tasks.append(delete_old_index)

@@ -219,7 +236,8 @@
    #    └─ create_filtered_index
    #       └─ promote (trigger_promote + wait_for_promote)
    #          └─ delete_old_index
    #             └─ promote_filtered_index (including delete filtered index)
    #             └─ promote_filtered_index (including delete filtered index) +
    #                enable_alarms
    chain(*tasks)

    return data_refresh_group
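For reference on `chain(*tasks)`: when an entry in `tasks` is a list, such as `[generate_index_suffix, disable_alarms]` or `[enable_alarms, promote_tasks]` above, Airflow's `chain` runs those tasks in parallel and wires each of them to the neighboring entries. A self-contained sketch of that behavior with hypothetical stand-in tasks (assumes Airflow 2.4+ for `EmptyOperator` and the `schedule` argument):

    import pendulum
    from airflow.decorators import dag
    from airflow.models.baseoperator import chain
    from airflow.operators.empty import EmptyOperator


    @dag(start_date=pendulum.datetime(2024, 2, 14), schedule=None, catchup=False)
    def chain_list_example():
        upstream = EmptyOperator(task_id="upstream")
        parallel_a = EmptyOperator(task_id="parallel_a")  # e.g. generate_index_suffix
        parallel_b = EmptyOperator(task_id="parallel_b")  # e.g. disable_alarms
        downstream = EmptyOperator(task_id="downstream")

        # Equivalent to: upstream >> [parallel_a, parallel_b] >> downstream
        chain(upstream, [parallel_a, parallel_b], downstream)


    chain_list_example()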
8 changes: 8 additions & 0 deletions catalog/env.template
@@ -50,6 +50,9 @@ TEST_CONN_ID=postgres_openledger_testing
AIRFLOW_CONN_ELASTICSEARCH_HTTP_PRODUCTION=http://es:9200
AIRFLOW_CONN_ELASTICSEARCH_HTTP_STAGING=http://es:9200

# AWS CloudWatch connection. Change the following line to toggle alarms during a Data Refresh.
# AIRFLOW_CONN_AWS_CLOUDWATCH=aws://<key>:<secret>@?region_name=us-east-1

# API DB connection. Change the following line in prod to use the appropriate DB
AIRFLOW_CONN_POSTGRES_OPENLEDGER_API=postgres://deploy:deploy@db:5432/openledger

@@ -92,6 +95,7 @@ [email protected]
# AWS/S3 configuration - does not need to be changed for development
AWS_ACCESS_KEY=test_key
AWS_SECRET_KEY=test_secret
AWS_DEFAULT_REGION=us-east-1
# General bucket used for TSV->DB ingestion and logging
OPENVERSE_BUCKET=openverse-storage
# Seconds to wait before poking for availability of the data refresh pool when running a data_refresh
@@ -117,3 +121,7 @@ SQLALCHEMY_SILENCE_UBER_WARNING=1

AIRFLOW_VAR_AIRFLOW_RDS_ARN=unset
AIRFLOW_VAR_AIRFLOW_RDS_SNAPSHOTS_TO_RETAIN=7

# Whether to toggle production CloudWatch alarms when running a data refresh DAG.
# Used to prevent requiring AWS credentials when running locally.
AIRFLOW_VAR_TOGGLE_CLOUDWATCH_ALARMS=false
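For reference, Airflow resolves Variables from `AIRFLOW_VAR_<NAME>` environment variables, and the `deserialize_json=True` argument used in `cloudwatch.enable_or_disable_alarms` is what turns the string `false` above into a Python `False`, triggering the `AirflowSkipException`. A small sketch of that resolution, assuming a running Airflow environment:

    import os

    from airflow.models import Variable

    # Mirrors the env.template default above; in a real deployment the value comes
    # from the environment rather than from code.
    os.environ["AIRFLOW_VAR_TOGGLE_CLOUDWATCH_ALARMS"] = "false"

    toggle = Variable.get("TOGGLE_CLOUDWATCH_ALARMS", True, deserialize_json=True)
    print(toggle)  # False -> enable_or_disable_alarms raises AirflowSkipException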
