diff --git a/catalog/dags/common/cloudwatch.py b/catalog/dags/common/cloudwatch.py
new file mode 100644
index 00000000000..0be3df340a3
--- /dev/null
+++ b/catalog/dags/common/cloudwatch.py
@@ -0,0 +1,72 @@
+"""
+CloudWatchWrapper, partially extracted from
+https://github.com/awsdocs/aws-doc-sdk-examples/blob/54c3b82d8f9a12a862f9fcec44909829bda849af/python/example_code/cloudwatch/cloudwatch_basics.py
+
+The CloudWatchWrapper requires the AWS_CLOUDWATCH_CONN_ID connection, or the
+`aws_default` fallback, to be set in the Airflow Connections.
+
+Modifying alarms can be skipped by setting the `TOGGLE_CLOUDWATCH_ALARMS` Airflow
+Variable to `False`, which is the desired behavior when running the Data Refresh
+DAGs locally or in a development environment.
+"""
+import logging
+
+from airflow.exceptions import AirflowSkipException
+from airflow.models import Variable
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from botocore.exceptions import ClientError
+
+from common.constants import AWS_CLOUDWATCH_CONN_ID
+
+
+logger = logging.getLogger(__name__)
+
+
+class CloudWatchWrapper:
+    """Encapsulates Amazon CloudWatch functions"""
+
+    def __init__(self, cloudwatch_resource):
+        """:param cloudwatch_resource: A Boto3 CloudWatch resource."""
+        self.cloudwatch_resource = cloudwatch_resource
+
+    def enable_alarm_actions(self, alarm_name, enable):
+        """
+        Enable or disable actions on the specified alarm. Alarm actions can be
+        used to send notifications or automate responses when an alarm enters a
+        particular state.
+
+        :param alarm_name: The name of the alarm.
+        :param enable: When True, actions are enabled for the alarm. Otherwise,
+        they are disabled.
+        """
+        try:
+            alarm = self.cloudwatch_resource.Alarm(alarm_name)
+            if enable:
+                alarm.enable_actions()
+            else:
+                alarm.disable_actions()
+            logger.info(
+                "%s actions for alarm %s.",
+                "Enabled" if enable else "Disabled",
+                alarm_name,
+            )
+        except ClientError:
+            logger.exception(
+                "Couldn't %s actions for alarm %s.",
+                "enable" if enable else "disable",
+                alarm_name,
+            )
+            raise
+
+
+def enable_or_disable_alarms(enable):
+    toggle = Variable.get("TOGGLE_CLOUDWATCH_ALARMS", True, deserialize_json=True)
+    if not toggle:
+        raise AirflowSkipException("TOGGLE_CLOUDWATCH_ALARMS is set to False.")
+
+    cloudwatch = AwsBaseHook(
+        aws_conn_id=AWS_CLOUDWATCH_CONN_ID,
+        resource_type="cloudwatch",
+    )
+    cw_wrapper = CloudWatchWrapper(cloudwatch.get_conn())
+    cw_wrapper.enable_alarm_actions("ES Production CPU utilization above 50%", enable)
diff --git a/catalog/dags/common/constants.py b/catalog/dags/common/constants.py
index c2de110341d..39c318bc5c2 100644
--- a/catalog/dags/common/constants.py
+++ b/catalog/dags/common/constants.py
@@ -40,6 +40,7 @@
 OPENLEDGER_API_CONN_ID = "postgres_openledger_api"
 POSTGRES_API_STAGING_CONN_ID = "postgres_openledger_api_staging"
 AWS_CONN_ID = "aws_default"
+AWS_CLOUDWATCH_CONN_ID = os.environ.get("AWS_CLOUDWATCH_CONN_ID", AWS_CONN_ID)
 AWS_RDS_CONN_ID = os.environ.get("AWS_RDS_CONN_ID", AWS_CONN_ID)
 ES_PROD_HTTP_CONN_ID = "elasticsearch_http_production"
 REFRESH_POKE_INTERVAL = int(os.getenv("DATA_REFRESH_POKE_INTERVAL", 60 * 30))
diff --git a/catalog/dags/common/ingestion_server.py b/catalog/dags/common/ingestion_server.py
index e8ba9f4c42d..a2b722248af 100644
--- a/catalog/dags/common/ingestion_server.py
+++ b/catalog/dags/common/ingestion_server.py
@@ -8,6 +8,7 @@
 from airflow.exceptions import AirflowSkipException
 from airflow.providers.http.operators.http import HttpOperator
 from airflow.providers.http.sensors.http import HttpSensor
+from airflow.utils.trigger_rule import TriggerRule
 from requests import Response
 
 from common.constants import (
@@ -111,6 +112,7 @@ def trigger_task(
     model: str,
     data: dict | None = None,
     http_conn_id: str = "data_refresh",
+    trigger_rule: TriggerRule = TriggerRule.ALL_SUCCESS,
 ) -> HttpOperator:
     data = {
         **(data or {}),
@@ -124,6 +126,7 @@ def trigger_task(
         data=data,
         response_check=lambda response: response.status_code == 202,
         response_filter=response_filter_status_check_endpoint,
+        trigger_rule=trigger_rule,
     )
 
 
@@ -133,6 +136,7 @@ def wait_for_task(
     timeout: timedelta,
     poke_interval: int = REFRESH_POKE_INTERVAL,
     http_conn_id: str = "data_refresh",
+    trigger_rule: TriggerRule = TriggerRule.ALL_SUCCESS,
 ) -> HttpSensor:
     return HttpSensor(
         task_id=f"wait_for_{action.lower()}",
@@ -143,6 +147,7 @@ def wait_for_task(
         mode="reschedule",
         poke_interval=poke_interval,
         timeout=timeout.total_seconds(),
+        trigger_rule=trigger_rule,
     )
 
 
@@ -153,9 +158,12 @@ def trigger_and_wait_for_task(
     data: dict | None = None,
     poke_interval: int = REFRESH_POKE_INTERVAL,
     http_conn_id: str = "data_refresh",
+    trigger_rule: TriggerRule = TriggerRule.ALL_SUCCESS,
 ) -> tuple[HttpOperator, HttpSensor]:
-    trigger = trigger_task(action, model, data, http_conn_id)
-    waiter = wait_for_task(action, trigger, timeout, poke_interval, http_conn_id)
+    trigger = trigger_task(action, model, data, http_conn_id, trigger_rule)
+    waiter = wait_for_task(
+        action, trigger, timeout, poke_interval, http_conn_id, trigger_rule
+    )
     trigger >> waiter
     return trigger, waiter
diff --git a/catalog/dags/data_refresh/data_refresh_task_factory.py b/catalog/dags/data_refresh/data_refresh_task_factory.py
index 0531746aaf0..8444f0641c0 100644
--- a/catalog/dags/data_refresh/data_refresh_task_factory.py
+++ b/catalog/dags/data_refresh/data_refresh_task_factory.py
@@ -33,7 +33,7 @@
 successful response will include the `status_check` url used to check on the
 status of the refresh, which is passed on to the next task via XCom.
 
-3. Finally the `wait_for_data_refresh` task waits for the data refresh to be
+3. Finally, the `wait_for_data_refresh` task waits for the data refresh to be
 complete by polling the `status_url`. Note this task does not need to be able
 to suspend itself and free the worker slot, because we want to lock the entire
 pool on waiting for a particular data refresh to run.
@@ -49,10 +49,11 @@
 from collections.abc import Sequence
 
 from airflow.models.baseoperator import chain
+from airflow.operators.python import PythonOperator
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.trigger_rule import TriggerRule
 
-from common import ingestion_server
+from common import cloudwatch, ingestion_server
 from common.constants import PRODUCTION, XCOM_PULL_TEMPLATE
 from common.sensors.single_run_external_dags_sensor import SingleRunExternalDAGsSensor
 from common.sensors.utils import wait_for_external_dags
@@ -137,7 +138,14 @@ def create_data_refresh_task_group(
     generate_index_suffix = ingestion_server.generate_index_suffix.override(
         trigger_rule=TriggerRule.NONE_FAILED,
     )()
-    tasks.append(generate_index_suffix)
+    disable_alarms = PythonOperator(
+        task_id="disable_sensitive_cloudwatch_alarms",
+        python_callable=cloudwatch.enable_or_disable_alarms,
+        op_kwargs={
+            "enable": False,
+        },
+    )
+    tasks.append([generate_index_suffix, disable_alarms])
 
     # Trigger the 'ingest_upstream' task on the ingestion server and await its
     # completion. This task copies the media table for the given model from the
@@ -147,10 +155,9 @@ def create_data_refresh_task_group(
     ingest_upstream_tasks = ingestion_server.trigger_and_wait_for_task(
         action="ingest_upstream",
         model=data_refresh.media_type,
-        data={
-            "index_suffix": generate_index_suffix,
-        },
+        data={"index_suffix": generate_index_suffix},
         timeout=data_refresh.data_refresh_timeout,
+        trigger_rule=TriggerRule.NONE_FAILED,
     )
     tasks.append(ingest_upstream_tasks)
 
@@ -181,6 +188,15 @@ def create_data_refresh_task_group(
     # running against an index that is already promoted in production.
     tasks.append(create_filtered_index)
 
+    enable_alarms = PythonOperator(
+        task_id="enable_sensitive_cloudwatch_alarms",
+        python_callable=cloudwatch.enable_or_disable_alarms,
+        op_kwargs={
+            "enable": True,
+        },
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
     # Trigger the `promote` task on the ingestion server and await its completion.
     # This task promotes the newly created API DB table and elasticsearch index. It
     # does not include promotion of the filtered index, which must be promoted
@@ -195,7 +211,7 @@ def create_data_refresh_task_group(
         },
         timeout=data_refresh.data_refresh_timeout,
     )
-    tasks.append(promote_tasks)
+    tasks.append([enable_alarms, promote_tasks])
 
     # Delete the alias' previous target index, now unused.
     delete_old_index = ingestion_server.trigger_task(
@@ -206,6 +222,7 @@ def create_data_refresh_task_group(
         action="DELETE_INDEX",
         model=data_refresh.media_type,
         data={
             "index_suffix": XCOM_PULL_TEMPLATE.format(
                 get_current_index.task_id, "return_value"
             ),
         },
+        trigger_rule=TriggerRule.NONE_FAILED,
     )
     tasks.append(delete_old_index)
 
@@ -219,7 +236,8 @@ def create_data_refresh_task_group(
     # └─ create_filtered_index
     # └─ promote (trigger_promote + wait_for_promote)
     # └─ delete_old_index
-    # └─ promote_filtered_index (including delete filtered index)
+    # └─ promote_filtered_index (including delete filtered index) +
+    #    enable_alarms
     chain(*tasks)
 
     return data_refresh_group
diff --git a/catalog/env.template b/catalog/env.template
index 93fed375936..7125009f6ec 100644
--- a/catalog/env.template
+++ b/catalog/env.template
@@ -50,6 +50,9 @@ TEST_CONN_ID=postgres_openledger_testing
 AIRFLOW_CONN_ELASTICSEARCH_HTTP_PRODUCTION=http://es:9200
 AIRFLOW_CONN_ELASTICSEARCH_HTTP_STAGING=http://es:9200
 
+# AWS CloudWatch connection. Change the following line to toggle alarms during a Data Refresh.
+# AIRFLOW_CONN_AWS_CLOUDWATCH=aws://:@?region_name=us-east-1
+
 # API DB connection. Change the following line in prod to use the appropriate DB
 AIRFLOW_CONN_POSTGRES_OPENLEDGER_API=postgres://deploy:deploy@db:5432/openledger
 
@@ -92,6 +95,7 @@ CONTACT_EMAIL=openverse@wordpress.org
 # AWS/S3 configuration - does not need to be changed for development
 AWS_ACCESS_KEY=test_key
 AWS_SECRET_KEY=test_secret
+AWS_DEFAULT_REGION=us-east-1
 # General bucket used for TSV->DB ingestion and logging
 OPENVERSE_BUCKET=openverse-storage
 # Seconds to wait before poking for availability of the data refresh pool when running a data_refresh
@@ -117,3 +121,7 @@ SQLALCHEMY_SILENCE_UBER_WARNING=1
 
 AIRFLOW_VAR_AIRFLOW_RDS_ARN=unset
 AIRFLOW_VAR_AIRFLOW_RDS_SNAPSHOTS_TO_RETAIN=7
+
+# Whether to toggle production CloudWatch alarms when running a data refresh DAG.
+# Used to prevent requiring AWS credentials when running locally.
+AIRFLOW_VAR_TOGGLE_CLOUDWATCH_ALARMS=false
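
A quick way to sanity-check the new toggle locally, without any AWS credentials: the
sketch below (not part of the diff above) mocks `Variable.get` and asserts that
`enable_or_disable_alarms` skips instead of reaching CloudWatch. The test name and
the use of pytest/unittest.mock are assumptions; only `common.cloudwatch` and its
behavior come from the change itself.

    # test_cloudwatch_toggle.py -- hypothetical local check, not included in the PR
    from unittest import mock

    import pytest
    from airflow.exceptions import AirflowSkipException

    from common import cloudwatch


    def test_skips_when_toggle_is_false():
        # Mirrors AIRFLOW_VAR_TOGGLE_CLOUDWATCH_ALARMS=false from env.template:
        # Variable.get deserializes to False, so the task should skip before
        # any AWS connection is created.
        with mock.patch.object(cloudwatch.Variable, "get", return_value=False):
            with pytest.raises(AirflowSkipException):
                cloudwatch.enable_or_disable_alarms(enable=True)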