Automatically exclude the running DAG from the ids to check
stacimc committed Mar 14, 2024
1 parent 283e3f6 commit 811c1a0
Showing 8 changed files with 68 additions and 47 deletions.
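In short: the concurrency helpers now read the running DAG's id from the task context and exclude it automatically, so each caller no longer has to list itself. A minimal before/after sketch of the caller-side effect, using names that appear in the diff below:

    # Before: every DAG had to exclude its own id by hand.
    prevent_concurrency_with_dags_with_tag(
        tag=STAGING_ES_CONCURRENCY_TAG,
        excluded_dag_ids=[DAG_ID],
    )

    # After: the running DAG's id is excluded automatically. excluded_dag_ids
    # remains only for DAGs whose concurrency is handled elsewhere.
    prevent_concurrency_with_dags_with_tag(tag=STAGING_ES_CONCURRENCY_TAG)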
91 changes: 57 additions & 34 deletions catalog/dags/common/sensors/utils.py
@@ -10,6 +10,9 @@
 from common.constants import REFRESH_POKE_INTERVAL
 
 
+THREE_DAYS = 60 * 60 * 24 * 3
+
+
 def _get_most_recent_dag_run(dag_id) -> list[datetime] | datetime:
     """
     Retrieve the most recent DAG run's execution date.
@@ -36,21 +39,40 @@ def _get_most_recent_dag_run(dag_id) -> list[datetime] | datetime:
     return []
 
 
-def _get_dags_with_tag(tag: str, excluded_dag_ids: list[str], session=None):
-    """Get a list of DAG ids with the given tag, optionally excluding certain ids."""
-    if not excluded_dag_ids:
-        excluded_dag_ids = []
-
+@task
+def get_dags_with_concurrency_tag(
+    tag: str, excluded_dag_ids: list[str], session=None, **context
+):
+    """
+    Get a list of DAG ids with the given tag. The id of the running DAG is excluded,
+    as well as any ids in the `excluded_dag_ids` list.
+    """
     dags = session.query(DagModel).filter(DagModel.tags.any(DagTag.name == tag)).all()
 
-    # Return just the ids, excluding excluded_dag_ids
-    ids = [dag.dag_id for dag in dags if dag.dag_id not in excluded_dag_ids]
-    return ids
+    dag_ids = [dag.dag_id for dag in dags]
+
+    running_dag_id = context["dag"].dag_id
+    if running_dag_id not in dag_ids:
+        raise ValueError(
+            f"The `{running_dag_id}` DAG tried preventing concurrency with the"
+            f" `{tag}` tag, but does not have the tag itself. To ensure that other"
+            f" DAGs with this tag will also avoid running concurrently with"
+            f" `{running_dag_id}`, it must have the `{tag}` tag applied."
+        )
+
+    # Return just the ids of DAGs to prevent concurrency with. This excludes the
+    # running dag id, and any supplied `excluded_dag_ids`.
+    return [id for id in dag_ids if id not in [*excluded_dag_ids, running_dag_id]]
 
 
-def wait_for_external_dag(external_dag_id: str, task_id: str | None = None):
+@task
+def wait_for_external_dag(
+    external_dag_id: str,
+    task_id: str | None = None,
+    timeout: int | None = THREE_DAYS,
+    **context,
+):
     """
-    Return a Sensor task which will wait if the given external DAG is
+    Execute a Sensor task which will wait if the given external DAG is
     running.
     To fully ensure that the waiting DAG and the external DAG do not run
@@ -64,7 +86,7 @@ def wait_for_external_dag(external_dag_id: str, task_id: str | None = None):
     if not task_id:
         task_id = f"wait_for_{external_dag_id}"
 
-    return ExternalTaskSensor(
+    sensor = ExternalTaskSensor(
         task_id=task_id,
         poke_interval=REFRESH_POKE_INTERVAL,
         external_dag_id=external_dag_id,
@@ -75,24 +97,28 @@ def wait_for_external_dag(external_dag_id: str, task_id: str | None = None):
         mode="reschedule",
         # Any "finished" state is sufficient for us to continue
         allowed_states=[State.SUCCESS, State.FAILED],
+        # execution_timeout for the task does not include time that the sensor
+        # was up for reschedule but not actually running. `timeout` does
+        timeout=timeout,
     )
 
+    sensor.execute(context)
+
 
 @task_group(group_id="wait_for_external_dags")
 @provide_session
 def wait_for_external_dags_with_tag(
-    tag: str, excluded_dag_ids: list[str], session=None, **context
+    tag: str, excluded_dag_ids: list[str] = None, session=None
 ):
     """
     Wait until all DAGs with the given `tag`, excluding those identified by the
     `excluded_dag_ids`, are no longer in the running state before continuing.
     """
-    external_dag_ids = _get_dags_with_tag(
-        tag=tag, excluded_dag_ids=excluded_dag_ids, session=session
-    )
+    external_dag_ids = get_dags_with_concurrency_tag.override(
+        task_id=f"get_dags_in_{tag}_group"
+    )(tag=tag, excluded_dag_ids=excluded_dag_ids or [], session=session)
 
-    for dag_id in external_dag_ids:
-        wait_for_external_dag(dag_id)
+    wait_for_external_dag.expand(external_dag_id=external_dag_ids)


@@ -101,36 +127,33 @@
 @task(retries=0)
 def prevent_concurrency_with_dag(external_dag_id: str, **context):
     """
     Prevent concurrency with the given external DAG, by failing
     immediately if that DAG is running.
     """
-    wait_for_dag = wait_for_external_dag(
-        external_dag_id=external_dag_id,
-        task_id=f"check_for_running_{external_dag_id}",
-    )
-    wait_for_dag.timeout = 0
     try:
-        wait_for_dag.execute(context)
+        wait_for_external_dag.function(
+            external_dag_id=external_dag_id,
+            task_id=f"check_for_running_{external_dag_id}",
+            timeout=0,
+            **context,
+        )
     except AirflowSensorTimeout:
         raise ValueError(f"Concurrency check with {external_dag_id} failed.")
 
 
 @task_group(group_id="prevent_concurrency_with_dags")
 @provide_session
 def prevent_concurrency_with_dags_with_tag(
-    tag: str, excluded_dag_ids: list[str], session=None, **context
+    tag: str, excluded_dag_ids: list[str] = None, session=None
 ):
     """
     Prevent concurrency with any DAGs that have the given `tag`, excluding
     those identified by the `excluded_dag_ids`. Concurrency is prevented by
     failing the task immediately if any of the tagged DAGs are in the running
     state.
     """
-    external_dag_ids = _get_dags_with_tag(
-        tag=tag, excluded_dag_ids=excluded_dag_ids, session=session
-    )
+    external_dag_ids = get_dags_with_concurrency_tag.override(
+        task_id=f"get_dags_in_{tag}_group"
+    )(tag=tag, excluded_dag_ids=excluded_dag_ids or [], session=session)
 
-    for external_dag_id in external_dag_ids:
-        prevent_concurrency_with_dag.override(
-            task_id=f"prevent_concurrency_with_{external_dag_id}"
-        )(external_dag_id)
+    prevent_concurrency_with_dag.expand(external_dag_id=external_dag_ids)


 @task(retries=0)
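Two Airflow idioms carry the refactor in utils.py above. A minimal sketch of each, with the tag and DAG id values invented for illustration:

    # Dynamic task mapping: the id list is produced at runtime by the upstream
    # task, and .expand() creates one mapped wait_for_external_dag instance per
    # id, replacing the previous static for-loop of sensors.
    external_dag_ids = get_dags_with_concurrency_tag.override(
        task_id="get_dags_in_my_tag_group"
    )(tag="my_concurrency_tag", excluded_dag_ids=[])
    wait_for_external_dag.expand(external_dag_id=external_dag_ids)

    # Direct invocation: .function on a @task-decorated callable is the plain
    # underlying Python function, which is how prevent_concurrency_with_dag
    # reuses the sensor logic inside its own task. With timeout=0, the sensor
    # raises AirflowSensorTimeout immediately if the external DAG is running.
    wait_for_external_dag.function(
        external_dag_id="some_external_dag",  # hypothetical id
        timeout=0,
        **context,  # the running task's Airflow context
    )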
1 change: 0 additions & 1 deletion catalog/dags/data_refresh/create_filtered_index_dag.py
@@ -121,7 +121,6 @@ def create_filtered_index_creation_dag(data_refresh: DataRefresh):
     # single production index simultaneously.
     prevent_concurrency = prevent_concurrency_with_dags_with_tag(
         tag=PRODUCTION_ES_CONCURRENCY_TAG,
-        excluded_dag_ids=[data_refresh.filtered_index_dag_id],
     )
 
     # Once the concurrency check has passed, actually create the filtered
6 changes: 3 additions & 3 deletions catalog/dags/data_refresh/data_refresh_task_factory.py
@@ -125,9 +125,9 @@ def create_data_refresh_task_group(
         group_id="wait_for_es_dags"
     )(
         tag=PRODUCTION_ES_CONCURRENCY_TAG,
-        # Exclude the current DAG id, as well as all other data refresh DAG ids (these
-        # are waited on in the previous task)
-        excluded_dag_ids=[*external_dag_ids, data_refresh.dag_id],
+        # Exclude the other data refresh DAG ids, as waiting on these was handled in
+        # the previous task.
+        excluded_dag_ids=external_dag_ids,
     )
     tasks.append([wait_for_data_refresh, wait_for_es_dags])
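This is the one call site that still passes excluded_dag_ids: the sibling data refresh DAGs are waited on by a dedicated upstream task, so their ids are excluded here to avoid redundant sensors. A sketch of the resulting call, assuming external_dag_ids holds those sibling DAG ids:

    wait_for_es_dags = wait_for_external_dags_with_tag.override(
        group_id="wait_for_es_dags"
    )(
        tag=PRODUCTION_ES_CONCURRENCY_TAG,
        # Only the sibling data refresh ids need to be listed now; the running
        # DAG's own id is excluded automatically.
        excluded_dag_ids=external_dag_ids,
    )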

@@ -80,7 +80,6 @@ def restore_staging_database():
     # Wait for any DAGs that operate on the staging elasticsearch cluster
     wait_for_recreate_full_staging_index = wait_for_external_dags_with_tag(
         tag=STAGING_ES_CONCURRENCY_TAG,
-        excluded_dag_ids=[constants.DAG_ID],
     )
     should_skip = skip_restore()
     latest_snapshot = get_latest_prod_snapshot()
@@ -228,7 +228,7 @@ def create_new_es_index_dag(config: CreateNewIndex):
     # Fail early if any other DAG that operates on the relevant elasticsearch cluster
     # is running
     prevent_concurrency = prevent_concurrency_with_dags_with_tag(
-        tag=config.concurrency_tag, excluded_dag_ids=[config.dag_id]
+        tag=config.concurrency_tag,
     )
 
     es_host = es.get_es_host(environment=config.environment)
@@ -106,7 +106,7 @@
 def create_proportional_by_source_staging_index():
     # Fail early if any conflicting DAGs are running
     prevent_concurrency = prevent_concurrency_with_dags_with_tag(
-        tag=STAGING_ES_CONCURRENCY_TAG, excluded_dag_ids=[DAG_ID]
+        tag=STAGING_ES_CONCURRENCY_TAG,
     )
 
     es_host = es.get_es_host(environment=STAGING)
@@ -37,9 +37,9 @@
 from common.sensors.utils import prevent_concurrency_with_dags_with_tag
 
 
-def point_es_alias_dag(environment: str, dag_id: str):
-    dag = DAG(
-        dag_id=dag_id,
+def point_es_alias_dag(environment: str):
+    dag = DAG(
+        dag_id=f"point_{environment}_alias",
         default_args=DAG_DEFAULT_ARGS,
         schedule=None,
         start_date=datetime(2024, 1, 31),
 
@@ -79,7 +79,7 @@ def point_es_alias_dag(environment: str, dag_id: str):
     # Fail early if any other DAG that operates on the elasticsearch cluster for
     # this environment is running
     prevent_concurrency = prevent_concurrency_with_dags_with_tag(
-        tag=ES_CONCURRENCY_TAGS[environment], excluded_dag_ids=[dag_id]
+        tag=ES_CONCURRENCY_TAGS[environment],
     )
 
     es_host = es.get_es_host(environment=environment)
 
@@ -104,4 +104,4 @@ def point_es_alias_dag(environment: str, dag_id: str):
 
 
 for environment in ENVIRONMENTS:
-    point_es_alias_dag(environment, f"point_{environment}_alias")
+    point_es_alias_dag(environment)
@@ -103,7 +103,7 @@ def recreate_full_staging_index():
     # Fail early if any other DAG that operates on the staging elasticsearch cluster
     # is running
     prevent_concurrency = prevent_concurrency_with_dags_with_tag(
-        tag=STAGING_ES_CONCURRENCY_TAG, excluded_dag_ids=[DAG_ID]
+        tag=STAGING_ES_CONCURRENCY_TAG,
    )
 
     target_alias = get_target_alias(
