Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add elasticsearch concurrency tags for Airflow #3921

Merged
merged 11 commits into from
Mar 26, 2024
19 changes: 19 additions & 0 deletions catalog/dags/common/sensors/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from common.constants import PRODUCTION, STAGING


# DagTags marking groups of DAGs whose members must not be run concurrently
# with one another.

# Per-environment tags for DAGs which affect the Elasticsearch cluster and
# therefore should not be run simultaneously.
PRODUCTION_ES_CONCURRENCY_TAG = "production_elasticsearch_concurrency"
STAGING_ES_CONCURRENCY_TAG = "staging_elasticsearch_concurrency"

# Tag for DAGs which affect the staging API database in such a way that they
# should not be run simultaneously.
STAGING_DB_CONCURRENCY_TAG = "staging_api_database_concurrency"

# Lookup from environment constant to that environment's Elasticsearch
# concurrency tag.
ES_CONCURRENCY_TAGS = {
    PRODUCTION: PRODUCTION_ES_CONCURRENCY_TAG,
    STAGING: STAGING_ES_CONCURRENCY_TAG,
}
108 changes: 81 additions & 27 deletions catalog/dags/common/sensors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@

from airflow.decorators import task, task_group
from airflow.exceptions import AirflowSensorTimeout
from airflow.models import DagRun
from airflow.models import DagModel, DagRun, DagTag
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.utils.session import provide_session
from airflow.utils.state import State

from common.constants import REFRESH_POKE_INTERVAL


def get_most_recent_dag_run(dag_id) -> list[datetime] | datetime:
THREE_DAYS = 60 * 60 * 24 * 3


def _get_most_recent_dag_run(dag_id) -> list[datetime] | datetime:
"""
Retrieve the most recent DAG run's execution date.

Expand All @@ -35,9 +39,40 @@ def get_most_recent_dag_run(dag_id) -> list[datetime] | datetime:
return []


def wait_for_external_dag(external_dag_id: str, task_id: str | None = None):
@task
def get_dags_with_concurrency_tag(
    tag: str, excluded_dag_ids: list[str], session=None, dag=None
):
    """
    Get a list of DAG ids with the given tag. The id of the running DAG is excluded,
    as well as any ids in the `excluded_dag_ids` list.

    :param tag: name of the DagTag used to select DAGs
    :param excluded_dag_ids: ids of tagged DAGs to leave out of the result
    :param session: database session used to query `DagModel`
    :param dag: the running DAG; only its `dag_id` is read
        (presumably supplied by Airflow from the task context -- confirm)
    :return: ids of all DAGs with the tag, minus the running DAG and the
        `excluded_dag_ids`
    :raises ValueError: if the running DAG does not itself have the tag
    """
    tagged_dags = (
        session.query(DagModel).filter(DagModel.tags.any(DagTag.name == tag)).all()
    )
    tagged_dag_ids = [tagged_dag.dag_id for tagged_dag in tagged_dags]

    # The running DAG must carry the tag itself; otherwise the other DAGs in the
    # group would have no way of discovering it, and the concurrency protection
    # would only work in one direction.
    running_dag_id = dag.dag_id
    if running_dag_id not in tagged_dag_ids:
        raise ValueError(
            f"The `{running_dag_id}` DAG tried preventing concurrency with the `{tag}`"
            " tag, but does not have the tag itself. To ensure that other DAGs with this"
            f" tag will also avoid running concurrently with `{running_dag_id}`, it must"
            f" have the `{tag}` tag applied."
        )

    # Return just the ids of DAGs to prevent concurrency with. This excludes the
    # running dag id, and any supplied `excluded_dag_ids`.
    ignored_dag_ids = {*excluded_dag_ids, running_dag_id}
    return [dag_id for dag_id in tagged_dag_ids if dag_id not in ignored_dag_ids]


@task
def wait_for_external_dag(
external_dag_id: str,
task_id: str | None = None,
timeout: int | None = THREE_DAYS,
**context,
):
"""
Execute a Sensor task which will wait if the given external DAG is
running.

To fully ensure that the waiting DAG and the external DAG do not run
Expand All @@ -51,28 +86,39 @@ def wait_for_external_dag(external_dag_id: str, task_id: str | None = None):
if not task_id:
task_id = f"wait_for_{external_dag_id}"

return ExternalTaskSensor(
sensor = ExternalTaskSensor(
task_id=task_id,
poke_interval=REFRESH_POKE_INTERVAL,
external_dag_id=external_dag_id,
# Wait for the whole DAG, not just a part of it
external_task_id=None,
check_existence=False,
execution_date_fn=lambda _: get_most_recent_dag_run(external_dag_id),
execution_date_fn=lambda _: _get_most_recent_dag_run(external_dag_id),
mode="reschedule",
# Any "finished" state is sufficient for us to continue
allowed_states=[State.SUCCESS, State.FAILED],
# execution_timeout for the task does not include time that the sensor
# was up for reschedule but not actually running. `timeout` does
timeout=timeout,
)

sensor.execute(context)


@task_group(group_id="wait_for_external_dags")
@provide_session
def wait_for_external_dags_with_tag(
    tag: str, excluded_dag_ids: list[str] | None = None, session=None
):
    """
    Wait until all DAGs with the given `tag`, excluding those identified by the
    `excluded_dag_ids`, are no longer in the running state before continuing.

    :param tag: name of the DagTag identifying the DAGs to wait for
    :param excluded_dag_ids: ids of tagged DAGs that should not be waited on;
        `None` is treated as an empty list
    :param session: database session, forwarded to the task that looks up
        the tagged DAGs
    """
    # The ids of the tagged DAGs are looked up by a task; the task id is
    # derived from the tag.
    external_dag_ids = get_dags_with_concurrency_tag.override(
        task_id=f"get_dags_in_{tag}_group"
    )(tag=tag, excluded_dag_ids=excluded_dag_ids or [], session=session)

    # Map one waiting task over each DAG id that was found.
    wait_for_external_dag.expand(external_dag_id=external_dag_ids)


@task(retries=0)
Expand All @@ -81,18 +127,35 @@ def prevent_concurrency_with_dag(external_dag_id: str, **context):
Prevent concurrency with the given external DAG, by failing
immediately if that DAG is running.
"""

wait_for_dag = wait_for_external_dag(
external_dag_id=external_dag_id,
task_id=f"check_for_running_{external_dag_id}",
)
wait_for_dag.timeout = 0
try:
wait_for_dag.execute(context)
wait_for_external_dag.function(
external_dag_id=external_dag_id,
task_id=f"check_for_running_{external_dag_id}",
timeout=0,
**context,
)
except AirflowSensorTimeout:
raise ValueError(f"Concurrency check with {external_dag_id} failed.")


@task_group(group_id="prevent_concurrency_with_dags")
@provide_session
def prevent_concurrency_with_dags_with_tag(
    tag: str, excluded_dag_ids: list[str] | None = None, session=None
):
    """
    Prevent concurrency with any DAGs that have the given `tag`, excluding
    those identified by the `excluded_dag_ids`. Concurrency is prevented by
    failing the task immediately if any of the tagged DAGs are in the running
    state.

    :param tag: name of the DagTag identifying the DAGs to check
    :param excluded_dag_ids: ids of tagged DAGs that should not be checked;
        `None` is treated as an empty list
    :param session: database session, forwarded to the task that looks up
        the tagged DAGs
    """
    # The ids of the tagged DAGs are looked up by a task; the task id is
    # derived from the tag.
    external_dag_ids = get_dags_with_concurrency_tag.override(
        task_id=f"get_dags_in_{tag}_group"
    )(tag=tag, excluded_dag_ids=excluded_dag_ids or [], session=session)

    # Map one fail-fast concurrency check over each DAG id that was found.
    prevent_concurrency_with_dag.expand(external_dag_id=external_dag_ids)


@task(retries=0)
def is_concurrent_with_any(external_dag_ids: list[str], **context):
"""
Expand All @@ -109,12 +172,3 @@ def is_concurrent_with_any(external_dag_ids: list[str], **context):

# Explicit return None to clarify expectations
return None


@task_group(group_id="prevent_concurrency")
def prevent_concurrency_with_dags(external_dag_ids: list[str]):
"""Fail immediately if any of the given external dags are in progress."""
for dag_id in external_dag_ids:
prevent_concurrency_with_dag.override(
task_id=f"prevent_concurrency_with_{dag_id}"
)(dag_id)
28 changes: 12 additions & 16 deletions catalog/dags/data_refresh/create_filtered_index_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@

There are two mechanisms that prevent this from happening:

1. The filtered index creation DAGs are not allowed to run if a data refresh
for the media type is already running.
1. The filtered index creation DAGs fail immediately if any of the DAGs that are
tagged with `production_elasticsearch_concurrency` (including the data
refreshes) are currently running.
2. The data refresh DAGs will wait for any pre-existing filtered index creation
DAG runs for the media type to finish before continuing.

Expand All @@ -56,15 +57,13 @@
from airflow import DAG
from airflow.models.param import Param

from common.constants import DAG_DEFAULT_ARGS, PRODUCTION
from common.sensors.utils import prevent_concurrency_with_dags
from common.constants import DAG_DEFAULT_ARGS
from common.sensors.constants import PRODUCTION_ES_CONCURRENCY_TAG
from common.sensors.utils import prevent_concurrency_with_dags_with_tag
from data_refresh.create_filtered_index import (
create_filtered_index_creation_task_groups,
)
from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS, DataRefresh
from elasticsearch_cluster.create_new_es_index.create_new_es_index_types import (
CREATE_NEW_INDEX_CONFIGS,
)


# Note: We can't use the TaskFlow `@dag` DAG factory decorator
Expand All @@ -88,7 +87,7 @@ def create_filtered_index_creation_dag(data_refresh: DataRefresh):
default_args=DAG_DEFAULT_ARGS,
schedule=None,
start_date=datetime(2023, 4, 1),
tags=["data_refresh"],
tags=["data_refresh", PRODUCTION_ES_CONCURRENCY_TAG],
max_active_runs=1,
catchup=False,
doc_md=__doc__,
Expand Down Expand Up @@ -117,14 +116,11 @@ def create_filtered_index_creation_dag(data_refresh: DataRefresh):
},
render_template_as_native_obj=True,
) as dag:
# Immediately fail if the associated data refresh is running, or the
# create_new_production_es_index DAG is running. This prevents multiple
# DAGs from reindexing from a single production index simultaneously.
prevent_concurrency = prevent_concurrency_with_dags(
external_dag_ids=[
data_refresh.dag_id,
CREATE_NEW_INDEX_CONFIGS[PRODUCTION].dag_id,
]
# Immediately fail if any DAG that operates on the production elasticsearch
# cluster is running. This prevents multiple DAGs from reindexing from a
# single production index simultaneously.
prevent_concurrency = prevent_concurrency_with_dags_with_tag(
tag=PRODUCTION_ES_CONCURRENCY_TAG,
)

# Once the concurrency check has passed, actually create the filtered
Expand Down
3 changes: 2 additions & 1 deletion catalog/dags/data_refresh/dag_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
OPENLEDGER_API_CONN_ID,
XCOM_PULL_TEMPLATE,
)
from common.sensors.constants import PRODUCTION_ES_CONCURRENCY_TAG
from common.sql import PGExecuteQueryOperator, single_value
from data_refresh.data_refresh_task_factory import create_data_refresh_task_group
from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS, DataRefresh
Expand Down Expand Up @@ -70,7 +71,7 @@ def create_data_refresh_dag(data_refresh: DataRefresh, external_dag_ids: Sequenc
max_active_runs=1,
catchup=False,
doc_md=__doc__,
tags=["data_refresh"],
tags=["data_refresh", PRODUCTION_ES_CONCURRENCY_TAG],
)

with dag:
Expand Down
20 changes: 10 additions & 10 deletions catalog/dags/data_refresh/data_refresh_task_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,14 @@
from airflow.utils.trigger_rule import TriggerRule

from common import cloudwatch, ingestion_server
from common.constants import PRODUCTION, XCOM_PULL_TEMPLATE
from common.constants import XCOM_PULL_TEMPLATE
from common.sensors.constants import PRODUCTION_ES_CONCURRENCY_TAG
from common.sensors.single_run_external_dags_sensor import SingleRunExternalDAGsSensor
from common.sensors.utils import wait_for_external_dags
from common.sensors.utils import wait_for_external_dags_with_tag
from data_refresh.create_filtered_index import (
create_filtered_index_creation_task_groups,
)
from data_refresh.data_refresh_types import DataRefresh
from elasticsearch_cluster.create_new_es_index.create_new_es_index_types import (
CREATE_NEW_INDEX_CONFIGS,
)


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -123,11 +121,13 @@ def create_data_refresh_task_group(
# Realistically the data refresh is too slow to beat the index creation process,
# even if it was triggered immediately after one of these DAGs; however, it is
# always safer to avoid the possibility of the race condition altogether.
wait_for_es_dags = wait_for_external_dags.override(group_id="wait_for_es_dags")(
external_dag_ids=[
data_refresh.filtered_index_dag_id,
CREATE_NEW_INDEX_CONFIGS[PRODUCTION].dag_id,
]
wait_for_es_dags = wait_for_external_dags_with_tag.override(
group_id="wait_for_es_dags"
)(
tag=PRODUCTION_ES_CONCURRENCY_TAG,
# Exclude the other data refresh DAG ids, as waiting on these was handled in
# the previous task.
excluded_dag_ids=external_dag_ids,
)
tasks.append([wait_for_data_refresh, wait_for_es_dags])

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
# Update the staging database
# Staging Database Restore DAG

This DAG is responsible for updating the staging database using the most recent
snapshot of the production database.
Expand Down Expand Up @@ -35,7 +35,8 @@
DAG_DEFAULT_ARGS,
POSTGRES_API_STAGING_CONN_ID,
)
from common.sensors.utils import wait_for_external_dag
from common.sensors.constants import STAGING_DB_CONCURRENCY_TAG
from common.sensors.utils import wait_for_external_dags_with_tag
from common.sql import PGExecuteQueryOperator
from database.staging_database_restore import constants
from database.staging_database_restore.staging_database_restore import (
Expand All @@ -48,9 +49,6 @@
restore_staging_from_snapshot,
skip_restore,
)
from elasticsearch_cluster.recreate_staging_index.recreate_full_staging_index import (
DAG_ID as RECREATE_STAGING_INDEX_DAG_ID,
)


log = logging.getLogger(__name__)
Expand All @@ -60,7 +58,7 @@
dag_id=constants.DAG_ID,
schedule="@monthly",
start_date=datetime(2023, 5, 1),
tags=["database"],
tags=["database", STAGING_DB_CONCURRENCY_TAG],
max_active_runs=1,
dagrun_timeout=timedelta(days=1),
catchup=False,
Expand All @@ -76,9 +74,10 @@
def restore_staging_database():
# If the `recreate_full_staging_index` DAG was manually triggered prior
# to the database restoration starting, we should wait for it to
# finish.
wait_for_recreate_full_staging_index = wait_for_external_dag(
external_dag_id=RECREATE_STAGING_INDEX_DAG_ID,
# finish. It is not necessary to wait on any of the other ES DAGs as
# they do not directly affect the database.
wait_for_recreate_full_staging_index = wait_for_external_dags_with_tag(
tag=STAGING_DB_CONCURRENCY_TAG
)
should_skip = skip_restore()
latest_snapshot = get_latest_prod_snapshot()
Expand Down
Loading
Loading