Add elasticsearch concurrency tags for Airflow #3910

Closed · wants to merge 16 commits

Changes from all commits

106 changes: 73 additions & 33 deletions catalog/dags/common/elasticsearch.py
@@ -3,10 +3,11 @@
from typing import Literal, Union

from airflow.decorators import task, task_group
from airflow.exceptions import AirflowSkipException
from airflow.models.connection import Connection
from airflow.providers.elasticsearch.hooks.elasticsearch import ElasticsearchPythonHook
from airflow.sensors.base import PokeReturnValue
from airflow.utils.trigger_rule import TriggerRule
from elasticsearch.exceptions import NotFoundError

from common.constants import REFRESH_POKE_INTERVAL

@@ -175,51 +176,90 @@ def refresh_index(es_host: str, index_name: str):


@task_group(group_id="point_alias")
def point_alias(index_name: str, alias: str, es_host: str):
def point_alias(
es_host: str,
target_index: str,
target_alias: str,
should_delete_old_index: bool = False,
):
"""
Point the target alias to the given index. If the alias is already being
used by one or more indices, it will first be removed from all of them.
"""
used by another index, it will be removed from this index first. Optionally,
that index may also be automatically deleted.

@task.branch
def check_if_alias_exists(alias: str, es_host: str):
"""Check if the alias already exists."""
es_conn = ElasticsearchPythonHook(hosts=[es_host]).get_conn
return (
"point_alias.remove_existing_alias"
if es_conn.indices.exists_alias(name=alias)
else "point_alias.point_new_alias"
)
Required Arguments:

es_host: Connection string for elasticsearch
target_index: Str identifier for the target index. May be either the index name
or an existing alias.
target_alias: The new alias to be applied to the target index

Optional Arguments:

should_delete_old_index: If True, the index previously pointed to by the target
alias (if one exists) will be deleted.
"""

@task
def remove_existing_alias(alias: str, es_host: str):
"""Remove the given alias from any indices to which it points."""
def get_existing_index(es_host: str, target_alias: str):
"""Get the index to which the target alias currently points, if it exists."""
if not target_alias:
raise AirflowSkipException("No target alias was provided.")

es_conn = ElasticsearchPythonHook(hosts=[es_host]).get_conn
response = es_conn.indices.delete_alias(
name=alias,
# Remove the alias from _all_ indices to which it currently
# applies
index="_all",
)
return response.get("acknowledged")

try:
response = es_conn.indices.get_alias(name=target_alias)
if len(response) > 1:
raise ValueError(
"Expected at most one existing index with target alias"
f"{target_alias}, but {len(response)} were found."
)
return list(response.keys())[0]
except NotFoundError:
logger.info(f"Target alias {target_alias} does not exist.")
return None

@task
def point_new_alias(
es_host: str,
index_name: str,
alias: str,
target_index: str,
existing_index: str,
target_alias: str,
):
"""
Remove the target_alias from the existing index to which it applies, if
applicable, and point it to the target_index in one atomic operation.
"""
if not target_alias:
raise AirflowSkipException("No target alias was provided.")

es_conn = ElasticsearchPythonHook(hosts=[es_host]).get_conn
response = es_conn.indices.put_alias(index=index_name, name=alias)

actions = []
if existing_index:
actions.append({"remove": {"index": existing_index, "alias": target_alias}})
actions.append({"add": {"index": target_index, "alias": target_alias}})
logger.info(f"Applying actions: {actions}")

response = es_conn.indices.update_aliases(body={"actions": actions})
return response.get("acknowledged")

exists_alias = check_if_alias_exists(alias, es_host)
remove_alias = remove_existing_alias(alias, es_host)
@task
def delete_old_index(es_host: str, index_name: str, should_delete_old_index: bool):
if not should_delete_old_index:
raise AirflowSkipException("`should_delete_old_index` is set to `False`.")
if not index_name:
raise AirflowSkipException("No applicable index to delete.")

es_conn = ElasticsearchPythonHook(hosts=[es_host]).get_conn
response = es_conn.indices.delete(index=index_name)
return response.get("acknowledged")

existing_index = get_existing_index(es_host, target_alias)

point_alias = point_new_alias(es_host, target_index, existing_index, target_alias)

point_alias = point_new_alias.override(
# The remove_alias task may be skipped.
trigger_rule=TriggerRule.NONE_FAILED,
)(es_host, index_name, alias)
delete_index = delete_old_index(es_host, existing_index, should_delete_old_index)

exists_alias >> [remove_alias, point_alias]
remove_alias >> point_alias
existing_index >> point_alias >> delete_index
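
For context, a minimal sketch of how a DAG might invoke the reworked `point_alias` task group. The `dag_id`, host, and index names below are placeholders for illustration only, not values from this PR:

```python
from datetime import datetime

from airflow import DAG

from common.elasticsearch import point_alias

# Hypothetical usage of the new point_alias signature; the dag_id, es_host,
# and index names are placeholders.
with DAG(
    dag_id="example_point_alias_usage",
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    point_alias(
        es_host="http://localhost:9200",
        target_index="image-filtered-20240101",
        target_alias="image-filtered",
        # When True, the index that previously held the alias is deleted
        # after the alias has been moved.
        should_delete_old_index=False,
    )
```

Because the remove and add actions are sent in a single `update_aliases` call, the alias never points to zero indices during the swap.
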
11 changes: 11 additions & 0 deletions catalog/dags/common/sensors/constants.py
@@ -0,0 +1,11 @@
from common.constants import PRODUCTION, STAGING


# DagTags used to establish a concurrency group for each environment
PRODUCTION_ES_CONCURRENCY_TAG = "production_es_concurrency"
STAGING_ES_CONCURRENCY_TAG = "staging_es_concurrency"

ES_CONCURRENCY_TAGS = {
PRODUCTION: PRODUCTION_ES_CONCURRENCY_TAG,
STAGING: STAGING_ES_CONCURRENCY_TAG,
}
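
A short sketch of how these constants might be consumed when building a DAG's tag list; the surrounding variables are illustrative only:

```python
from common.constants import PRODUCTION
from common.sensors.constants import ES_CONCURRENCY_TAGS

# Illustrative only: look up the concurrency tag for the environment this DAG
# operates against and include it in the DAG's tags so that the tag-based
# sensor utilities can discover the DAG.
environment = PRODUCTION
concurrency_tag = ES_CONCURRENCY_TAGS[environment]  # "production_es_concurrency"
dag_tags = ["data_refresh", concurrency_tag]
```
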
108 changes: 81 additions & 27 deletions catalog/dags/common/sensors/utils.py
@@ -2,14 +2,18 @@

from airflow.decorators import task, task_group
from airflow.exceptions import AirflowSensorTimeout
from airflow.models import DagRun
from airflow.models import DagModel, DagRun, DagTag
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.utils.session import provide_session
from airflow.utils.state import State

from common.constants import REFRESH_POKE_INTERVAL


def get_most_recent_dag_run(dag_id) -> list[datetime] | datetime:
THREE_DAYS = 60 * 60 * 24 * 3


def _get_most_recent_dag_run(dag_id) -> list[datetime] | datetime:
"""
Retrieve the most recent DAG run's execution date.

@@ -35,9 +39,40 @@ def get_most_recent_dag_run(dag_id) -> list[datetime] | datetime:
return []


def wait_for_external_dag(external_dag_id: str, task_id: str | None = None):
@task
def get_dags_with_concurrency_tag(
tag: str, excluded_dag_ids: list[str], session=None, **context
):
"""
Get a list of DAG ids with the given tag. The id of the running DAG is excluded,
as well as any ids in the `excluded_dag_ids` list.
"""
Return a Sensor task which will wait if the given external DAG is
dags = session.query(DagModel).filter(DagModel.tags.any(DagTag.name == tag)).all()
dag_ids = [dag.dag_id for dag in dags]

running_dag_id = context["dag"].dag_id
if running_dag_id not in dag_ids:
raise ValueError(
f"The `{running_dag_id}` DAG tried preventing concurrency with the `{tag}`,"
" tag, but does not have the tag itself. To ensure that other DAGs with this"
f" tag will also avoid running concurrently with `{running_dag_id}`, it must"
f"have the `{tag}` tag applied."
)

# Return just the ids of DAGs to prevent concurrency with. This excludes the running dag id,
# and any supplied `excluded_dag_ids`
return [id for id in dag_ids if id not in [*excluded_dag_ids, running_dag_id]]


@task
def wait_for_external_dag(
external_dag_id: str,
task_id: str | None = None,
timeout: int | None = THREE_DAYS,
**context,
):
"""
Execute a Sensor task which will wait if the given external DAG is
running.

To fully ensure that the waiting DAG and the external DAG do not run
@@ -51,28 +86,39 @@ def wait_for_external_dag(external_dag_id: str, task_id: str | None = None):
if not task_id:
task_id = f"wait_for_{external_dag_id}"

return ExternalTaskSensor(
sensor = ExternalTaskSensor(
task_id=task_id,
poke_interval=REFRESH_POKE_INTERVAL,
external_dag_id=external_dag_id,
# Wait for the whole DAG, not just a part of it
external_task_id=None,
check_existence=False,
execution_date_fn=lambda _: get_most_recent_dag_run(external_dag_id),
execution_date_fn=lambda _: _get_most_recent_dag_run(external_dag_id),
mode="reschedule",
# Any "finished" state is sufficient for us to continue
allowed_states=[State.SUCCESS, State.FAILED],
# execution_timeout for the task does not include time that the sensor
# was up for reschedule but not actually running. `timeout` does
timeout=timeout,
)

sensor.execute(context)


@task_group(group_id="wait_for_external_dags")
def wait_for_external_dags(external_dag_ids: list[str]):
@provide_session
def wait_for_external_dags_with_tag(
tag: str, excluded_dag_ids: list[str] = None, session=None
):
"""
Wait for all DAGs with the given external DAG ids to no longer be
in a running state before continuing.
Wait until all DAGs with the given `tag`, excluding those identified by the
`excluded_dag_ids`, are no longer in the running state before continuing.
"""
for dag_id in external_dag_ids:
wait_for_external_dag(dag_id)
external_dag_ids = get_dags_with_concurrency_tag.override(
task_id=f"get_dags_in_{tag}_group"
)(tag=tag, excluded_dag_ids=excluded_dag_ids or [], session=session)

wait_for_external_dag.expand(external_dag_id=external_dag_ids)


@task(retries=0)
@@ -81,18 +127,35 @@ def prevent_concurrency_with_dag(external_dag_id: str, **context):
Prevent concurrency with the given external DAG, by failing
immediately if that DAG is running.
"""

wait_for_dag = wait_for_external_dag(
external_dag_id=external_dag_id,
task_id=f"check_for_running_{external_dag_id}",
)
wait_for_dag.timeout = 0
try:
wait_for_dag.execute(context)
wait_for_external_dag.function(
external_dag_id=external_dag_id,
task_id=f"check_for_running_{external_dag_id}",
timeout=0,
**context,
)
except AirflowSensorTimeout:
raise ValueError(f"Concurrency check with {external_dag_id} failed.")


@task_group(group_id="prevent_concurrency_with_dags")
@provide_session
def prevent_concurrency_with_dags_with_tag(
tag: str, excluded_dag_ids: list[str] = None, session=None
):
"""
Prevent concurrency with any DAGs that have the given `tag`, excluding
those identified by the `excluded_dag_ids`. Concurrency is prevented by
failing the task immediately if any of the tagged DAGs are in the running
state.
"""
external_dag_ids = get_dags_with_concurrency_tag.override(
task_id=f"get_dags_in_{tag}_group"
)(tag=tag, excluded_dag_ids=excluded_dag_ids or [], session=session)

prevent_concurrency_with_dag.expand(external_dag_id=external_dag_ids)


@task(retries=0)
def is_concurrent_with_any(external_dag_ids: list[str], **context):
"""
@@ -109,12 +172,3 @@ def is_concurrent_with_any(external_dag_ids: list[str], **context):

# Explicit return None to clarify expectations
return None


@task_group(group_id="prevent_concurrency")
def prevent_concurrency_with_dags(external_dag_ids: list[str]):
"""Fail immediately if any of the given external dags are in progress."""
for dag_id in external_dag_ids:
prevent_concurrency_with_dag.override(
task_id=f"prevent_concurrency_with_{dag_id}"
)(dag_id)
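
To illustrate how the new tag-based utilities fit together, a hypothetical staging DAG might look like the sketch below. The `dag_id` and the downstream task are placeholders; only the concurrency pieces reflect this PR:

```python
from datetime import datetime

from airflow import DAG
from airflow.decorators import task

from common.sensors.constants import STAGING_ES_CONCURRENCY_TAG
from common.sensors.utils import prevent_concurrency_with_dags_with_tag

with DAG(
    dag_id="example_staging_es_dag",  # placeholder
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
    # The DAG must carry the tag itself, or get_dags_with_concurrency_tag
    # raises a ValueError at runtime.
    tags=[STAGING_ES_CONCURRENCY_TAG],
):

    @task
    def do_es_work():
        ...

    # Fail immediately if any other DAG tagged for the staging cluster is running.
    prevent_concurrency = prevent_concurrency_with_dags_with_tag(
        tag=STAGING_ES_CONCURRENCY_TAG,
    )

    prevent_concurrency >> do_es_work()
```

DAGs that should block and wait rather than fail can instead call `wait_for_external_dags_with_tag` with the same tag.
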
28 changes: 12 additions & 16 deletions catalog/dags/data_refresh/create_filtered_index_dag.py
@@ -42,8 +42,9 @@

There are two mechanisms that prevent this from happening:

1. The filtered index creation DAGs are not allowed to run if a data refresh
for the media type is already running.
1. The filtered index creation DAGs fail immediately if any of the DAGs that are
tagged as part of the `production_es_concurrency` group (including the data
refreshes) are currently running.
2. The data refresh DAGs will wait for any pre-existing filtered index creation
DAG runs for the media type to finish before continuing.

@@ -56,15 +57,13 @@
from airflow import DAG
from airflow.models.param import Param

from common.constants import DAG_DEFAULT_ARGS, PRODUCTION
from common.sensors.utils import prevent_concurrency_with_dags
from common.constants import DAG_DEFAULT_ARGS
from common.sensors.constants import PRODUCTION_ES_CONCURRENCY_TAG
from common.sensors.utils import prevent_concurrency_with_dags_with_tag
from data_refresh.create_filtered_index import (
create_filtered_index_creation_task_groups,
)
from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS, DataRefresh
from elasticsearch_cluster.create_new_es_index.create_new_es_index_types import (
CREATE_NEW_INDEX_CONFIGS,
)


# Note: We can't use the TaskFlow `@dag` DAG factory decorator
@@ -88,7 +87,7 @@ def create_filtered_index_creation_dag(data_refresh: DataRefresh):
default_args=DAG_DEFAULT_ARGS,
schedule=None,
start_date=datetime(2023, 4, 1),
tags=["data_refresh"],
tags=["data_refresh", PRODUCTION_ES_CONCURRENCY_TAG],
max_active_runs=1,
catchup=False,
doc_md=__doc__,
@@ -117,14 +116,11 @@ def create_filtered_index_creation_dag(data_refresh: DataRefresh):
},
render_template_as_native_obj=True,
) as dag:
# Immediately fail if the associated data refresh is running, or the
# create_new_production_es_index DAG is running. This prevents multiple
# DAGs from reindexing from a single production index simultaneously.
prevent_concurrency = prevent_concurrency_with_dags(
external_dag_ids=[
data_refresh.dag_id,
CREATE_NEW_INDEX_CONFIGS[PRODUCTION].dag_id,
]
# Immediately fail if any DAG that operates on the production elasticsearch
# cluster is running. This prevents multiple DAGs from reindexing from a
# single production index simultaneously.
prevent_concurrency = prevent_concurrency_with_dags_with_tag(
tag=PRODUCTION_ES_CONCURRENCY_TAG,
)

# Once the concurrency check has passed, actually create the filtered
3 changes: 2 additions & 1 deletion catalog/dags/data_refresh/dag_factory.py
@@ -34,6 +34,7 @@
OPENLEDGER_API_CONN_ID,
XCOM_PULL_TEMPLATE,
)
from common.sensors.constants import PRODUCTION_ES_CONCURRENCY_TAG
from common.sql import PGExecuteQueryOperator, single_value
from data_refresh.data_refresh_task_factory import create_data_refresh_task_group
from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS, DataRefresh
@@ -70,7 +71,7 @@ def create_data_refresh_dag(data_refresh: DataRefresh, external_dag_ids: Sequenc
max_active_runs=1,
catchup=False,
doc_md=__doc__,
tags=["data_refresh"],
tags=["data_refresh", PRODUCTION_ES_CONCURRENCY_TAG],
)

with dag:
Expand Down