Add DAG for creating staging indices (#3232)
* Optionally override ES environment in ingestion server

* Revert "Optionally override ES environment in ingestion server"

This reverts commit 4840acf.

* Add staging ingestion server connection

* Make ingestion server utilities accept optional http_conn_id

* Make shared task for notifying Slack

* Consider ingestion server task active if it has active workers

This fixes a bug where an ingestion server task was considered to be in the "errored" state by the TaskStatus after it scheduled some indexer workers and then completed: in that state the task is no longer alive, but progress has not yet reached 100%. By also checking whether there are active workers associated with the task id, we can correctly determine whether the task is actually in an errored state (see the sketch after this commit list).

* Add recreate_full_staging_index DAGs

* Add DAG docs

* Test DAG parsing

* Remove unused constant from previous implementation

* Update DAG tags

* Respect the data_refresh_limit in reindex task as well as ingest_upstream

* Clarify defaults for data_refresh_limit

* Simplify the params by making target_alias customizable

* Make media_type a DAG param, refactor with TaskFlow

* Prevent staging DB restore and index creation from running simultaneously

* Fix imports to prevent errors when filling DagBag

* Update DAG docs

* Avoid module name conflicts
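
A minimal sketch of the "active workers" status check described in the commit above. The class and field names here are illustrative assumptions, not the ingestion server's actual TaskStatus API:

```python
from dataclasses import dataclass


@dataclass
class TaskStatus:
    # Illustrative fields only; the real ingestion server tracks these differently.
    alive: bool               # is the ingestion server task itself still running?
    percent_completed: float  # overall reindexing progress, 0-100
    active_workers: bool      # are any indexer workers for this task id still running?

    @property
    def errored(self) -> bool:
        # Previously: a task that scheduled indexer workers and then completed was
        # reported as errored, because it was no longer alive while progress < 100.
        # Now: the task also counts as active while its indexer workers are active.
        finished = self.percent_completed >= 100
        active = self.alive or self.active_workers
        return not finished and not active
```
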
stacimc authored Dec 6, 2023
1 parent ab4c429 commit ee77ef5
Showing 19 changed files with 457 additions and 71 deletions.
17 changes: 11 additions & 6 deletions catalog/dags/common/ingestion_server.py
@@ -93,10 +93,12 @@ def generate_index_suffix(default_suffix: str | None = None) -> str:
return default_suffix or uuid.uuid4().hex


def get_current_index(target_alias: str) -> SimpleHttpOperator:
def get_current_index(
target_alias: str, http_conn_id: str = "data_refresh"
) -> SimpleHttpOperator:
return SimpleHttpOperator(
task_id="get_current_index",
http_conn_id="data_refresh",
http_conn_id=http_conn_id,
endpoint=f"stat/{target_alias}",
method="GET",
response_check=lambda response: response.status_code == 200,
@@ -108,6 +110,7 @@ def trigger_task(
action: str,
model: str,
data: dict | None = None,
http_conn_id: str = "data_refresh",
) -> SimpleHttpOperator:
data = {
**(data or {}),
@@ -116,7 +119,7 @@ }
}
return SimpleHttpOperator(
task_id=f"trigger_{action.lower()}",
http_conn_id="data_refresh",
http_conn_id=http_conn_id,
endpoint="task",
data=data,
response_check=lambda response: response.status_code == 202,
@@ -129,10 +132,11 @@ def wait_for_task(
task_trigger: SimpleHttpOperator,
timeout: timedelta,
poke_interval: int = REFRESH_POKE_INTERVAL,
http_conn_id: str = "data_refresh",
) -> HttpSensor:
return HttpSensor(
task_id=f"wait_for_{action.lower()}",
http_conn_id="data_refresh",
http_conn_id=http_conn_id,
endpoint=XCOM_PULL_TEMPLATE.format(task_trigger.task_id, "return_value"),
method="GET",
response_check=response_check_wait_for_completion,
@@ -148,9 +152,10 @@ def trigger_and_wait_for_task(
timeout: timedelta,
data: dict | None = None,
poke_interval: int = REFRESH_POKE_INTERVAL,
http_conn_id: str = "data_refresh",
) -> tuple[SimpleHttpOperator, HttpSensor]:
trigger = trigger_task(action, model, data)
waiter = wait_for_task(action, trigger, timeout, poke_interval)
trigger = trigger_task(action, model, data, http_conn_id)
waiter = wait_for_task(action, trigger, timeout, poke_interval, http_conn_id)
trigger >> waiter
return trigger, waiter

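With the new optional `http_conn_id`, callers can point these utilities at a different ingestion server connection, such as the new staging one. A hedged usage sketch (the index suffix value is illustrative); the staging index DAG later in this diff makes essentially this call:

```python
from datetime import timedelta

from common import ingestion_server

# Within a DAG definition: trigger a REINDEX on the staging ingestion server and
# wait for it to finish, using the new "staging_data_refresh" connection instead
# of the default "data_refresh" connection.
trigger, waiter = ingestion_server.trigger_and_wait_for_task(
    action="REINDEX",
    model="image",
    data={"index_suffix": "abc123"},  # illustrative suffix
    timeout=timedelta(days=1),
    http_conn_id="staging_data_refresh",
)
```
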
16 changes: 16 additions & 0 deletions catalog/dags/common/slack.py
@@ -55,6 +55,7 @@
from os.path import basename
from typing import Any

from airflow.decorators import task
from airflow.exceptions import AirflowNotFoundException
from airflow.models import Variable
from airflow.providers.http.hooks.http import HttpHook
@@ -404,3 +405,18 @@ def on_failure_callback(context: dict) -> None:
task_id=task_id,
username="Airflow DAG Failure",
)


@task
def notify_slack(
text: str,
dag_id: str,
username: str = "Airflow Notification",
icon_emoji: str = ":airflow:",
) -> None:
send_message(
text,
username=username,
icon_emoji=icon_emoji,
dag_id=dag_id,
)
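
The shared `notify_slack` task can then be dropped into any DAG; a hedged example (the task id, message text, and dag_id below are illustrative):

```python
from common.slack import notify_slack

# Within a DAG definition: send a completion message, overriding the task_id so
# it can be targeted by a branch (e.g. "notify_complete").
notify_complete = notify_slack.override(task_id="notify_complete")(
    text="Finished recreating the full staging index.",
    dag_id="recreate_full_staging_index",
)
```
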
2 changes: 1 addition & 1 deletion catalog/dags/data_refresh/create_filtered_index_dag.py
@@ -161,5 +161,5 @@ def prevent_concurrency_with_data_refresh(**context):
return dag


for data_refresh in DATA_REFRESH_CONFIGS:
for data_refresh in DATA_REFRESH_CONFIGS.values():
create_filtered_index_dag = create_filtered_index_creation_dag(data_refresh)
4 changes: 2 additions & 2 deletions catalog/dags/data_refresh/dag_factory.py
@@ -135,9 +135,9 @@ def create_data_refresh_dag(data_refresh: DataRefresh, external_dag_ids: Sequenc


# Generate a data refresh DAG for each DATA_REFRESH_CONFIG.
all_data_refresh_dag_ids = {refresh.dag_id for refresh in DATA_REFRESH_CONFIGS}
all_data_refresh_dag_ids = {refresh.dag_id for refresh in DATA_REFRESH_CONFIGS.values()}

for data_refresh in DATA_REFRESH_CONFIGS:
for data_refresh in DATA_REFRESH_CONFIGS.values():
# Construct a set of all data refresh DAG ids other than the current DAG
other_dag_ids = all_data_refresh_dag_ids - {data_refresh.dag_id}

28 changes: 7 additions & 21 deletions catalog/dags/data_refresh/data_refresh_task_factory.py
@@ -46,11 +46,9 @@
"""
import logging
import os
import uuid
from collections.abc import Sequence

from airflow.models.baseoperator import chain
from airflow.operators.python import PythonOperator
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.utils.state import State
from airflow.utils.task_group import TaskGroup
@@ -150,11 +148,9 @@ def create_data_refresh_task_group(
tasks.append(get_current_index)

# Generate a UUID suffix that will be used by the newly created index.
generate_index_suffix = PythonOperator(
task_id="generate_index_suffix",
python_callable=lambda: uuid.uuid4().hex,
generate_index_suffix = ingestion_server.generate_index_suffix.override(
trigger_rule=TriggerRule.NONE_FAILED,
)
)()
tasks.append(generate_index_suffix)

# Trigger the 'ingest_upstream' task on the ingestion server and await its
@@ -166,9 +162,7 @@
action="ingest_upstream",
model=data_refresh.media_type,
data={
"index_suffix": XCOM_PULL_TEMPLATE.format(
generate_index_suffix.task_id, "return_value"
),
"index_suffix": generate_index_suffix,
},
timeout=data_refresh.data_refresh_timeout,
)
@@ -177,9 +171,7 @@
# Await healthy results from the newly created elasticsearch index.
index_readiness_check = ingestion_server.index_readiness_check(
media_type=data_refresh.media_type,
index_suffix=XCOM_PULL_TEMPLATE.format(
generate_index_suffix.task_id, "return_value"
),
index_suffix=generate_index_suffix,
timeout=data_refresh.index_readiness_timeout,
)
tasks.append(index_readiness_check)
@@ -191,14 +183,10 @@
promote_filtered_index,
) = create_filtered_index_creation_task_groups(
data_refresh=data_refresh,
origin_index_suffix=XCOM_PULL_TEMPLATE.format(
generate_index_suffix.task_id, "return_value"
),
origin_index_suffix=generate_index_suffix,
# Match origin and destination suffixes so we can tell which
# filtered indexes were created as part of a data refresh.
destination_index_suffix=XCOM_PULL_TEMPLATE.format(
generate_index_suffix.task_id, "return_value"
),
destination_index_suffix=generate_index_suffix,
)

# Add the task group for triggering the filtered index creation and awaiting its
@@ -216,9 +204,7 @@
action="promote",
model=data_refresh.media_type,
data={
"index_suffix": XCOM_PULL_TEMPLATE.format(
generate_index_suffix.task_id, "return_value"
),
"index_suffix": generate_index_suffix,
"alias": target_alias,
},
timeout=data_refresh.data_refresh_timeout,
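The refactor above leans on TaskFlow: calling a `@task`-decorated function returns an XComArg that can be passed directly to downstream tasks and templated operator fields, replacing the manual `XCOM_PULL_TEMPLATE` jinja pulls. A small self-contained illustration of the pattern (not Openverse code):

```python
import uuid

import pendulum
from airflow.decorators import dag, task
from airflow.utils.trigger_rule import TriggerRule


@dag(start_date=pendulum.datetime(2023, 12, 1), schedule=None, catchup=False)
def xcomarg_example():
    @task
    def generate_index_suffix() -> str:
        # Stands in for the shared ingestion_server.generate_index_suffix task.
        return uuid.uuid4().hex

    @task
    def use_suffix(index_suffix: str) -> None:
        print(f"Creating index with suffix {index_suffix}")

    # Calling the decorated task returns an XComArg; passing it along both wires
    # the dependency and resolves to the actual suffix value at runtime.
    suffix = generate_index_suffix.override(trigger_rule=TriggerRule.NONE_FAILED)()
    use_suffix(suffix)


xcomarg_example()
```
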
10 changes: 5 additions & 5 deletions catalog/dags/data_refresh/data_refresh_types.py
@@ -9,7 +9,7 @@
from dataclasses import dataclass, field
from datetime import datetime, timedelta

from common.constants import REFRESH_POKE_INTERVAL
from common.constants import AUDIO, IMAGE, REFRESH_POKE_INTERVAL


@dataclass
@@ -71,8 +71,8 @@ def __post_init__(self):
self.dag_id = f"{self.media_type}_data_refresh"


DATA_REFRESH_CONFIGS = [
DataRefresh(
DATA_REFRESH_CONFIGS = {
IMAGE: DataRefresh(
media_type="image",
data_refresh_timeout=timedelta(days=4),
refresh_metrics_timeout=timedelta(hours=24),
@@ -84,11 +84,11 @@ def __post_init__(self):
os.getenv("DATA_REFRESH_POKE_INTERVAL", 60 * 30)
),
),
DataRefresh(
AUDIO: DataRefresh(
media_type="audio",
data_refresh_poke_interval=int(
os.getenv("DATA_REFRESH_POKE_INTERVAL", 60 * 30)
),
filtered_index_poke_interval=int(os.getenv("DATA_REFRESH_POKE_INTERVAL", 60)),
),
]
}
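
Keying `DATA_REFRESH_CONFIGS` by media type keeps DAG generation the same (iterate over `.values()`) while letting other DAGs look a config up directly; for example:

```python
from datetime import timedelta

from common.constants import IMAGE
from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS

# DAG factories still iterate over every config...
for data_refresh in DATA_REFRESH_CONFIGS.values():
    print(data_refresh.dag_id)

# ...while the new staging index DAG can look up the config for a single media
# type, falling back to a default timeout if none is found.
config = DATA_REFRESH_CONFIGS.get(IMAGE)
data_refresh_timeout = config.data_refresh_timeout if config else timedelta(days=1)
```
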
@@ -28,13 +28,20 @@
from airflow.decorators import dag
from airflow.providers.amazon.aws.operators.rds import RdsDeleteDbInstanceOperator
from airflow.providers.amazon.aws.sensors.rds import RdsSnapshotExistenceSensor
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule
from es.recreate_staging_index.recreate_full_staging_index import (
DAG_ID as RECREATE_STAGING_INDEX_DAG_ID,
)

from common.constants import (
AWS_RDS_CONN_ID,
DAG_DEFAULT_ARGS,
POSTGRES_API_STAGING_CONN_ID,
REFRESH_POKE_INTERVAL,
)
from common.sensors.utils import get_most_recent_dag_run
from common.sql import PGExecuteQueryOperator
from database.staging_database_restore import constants
from database.staging_database_restore.staging_database_restore import (
@@ -70,9 +77,26 @@
render_template_as_native_obj=True,
)
def restore_staging_database():
# If the `recreate_full_staging_index` DAG was manually triggered prior
# to the database restoration starting, we should wait for it to
# finish.
wait_for_recreate_full_staging_index = ExternalTaskSensor(
task_id="wait_for_recreate_full_staging_index",
external_dag_id=RECREATE_STAGING_INDEX_DAG_ID,
# Wait for the whole DAG, not just a part of it
external_task_id=None,
check_existence=False,
poke_interval=REFRESH_POKE_INTERVAL,
execution_date_fn=lambda _: get_most_recent_dag_run(
RECREATE_STAGING_INDEX_DAG_ID
),
mode="reschedule",
# Any "finished" state is sufficient for us to continue.
allowed_states=[State.SUCCESS, State.FAILED],
)
should_skip = skip_restore()
latest_snapshot = get_latest_prod_snapshot()
should_skip >> latest_snapshot
wait_for_recreate_full_staging_index >> should_skip >> latest_snapshot

ensure_snapshot_ready = RdsSnapshotExistenceSensor(
task_id="ensure_snapshot_ready",
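
The staging database restore DAG above targets the most recent run of the other DAG via `execution_date_fn`, since a manually triggered DAG's logical date will not match this DAG's own. A sketch of what a helper like `get_most_recent_dag_run` presumably does; the real implementation lives in `common/sensors/utils.py` and may differ:

```python
from datetime import datetime

from airflow.models import DagRun


def get_most_recent_dag_run(dag_id: str) -> datetime | list[datetime]:
    """Return the logical date the sensor should wait on for ``dag_id``."""
    dag_runs = DagRun.find(dag_id=dag_id)
    if dag_runs:
        # Wait on the latest run of the external DAG.
        return max(run.execution_date for run in dag_runs)
    # If the external DAG has never run, return an empty list so the sensor has
    # nothing to wait on (assumed behavior).
    return []
```
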
@@ -0,0 +1,95 @@
from datetime import timedelta

from airflow.decorators import task, task_group
from airflow.exceptions import AirflowSensorTimeout
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.utils.state import State

from common import ingestion_server
from common.sensors.utils import get_most_recent_dag_run
from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS
from database.staging_database_restore.constants import (
DAG_ID as STAGING_DB_RESTORE_DAG_ID,
)


DAG_ID = "recreate_full_staging_index"


@task(retries=0)
def prevent_concurrency_with_staging_database_restore(**context):
wait_for_dag = ExternalTaskSensor(
task_id="check_for_running_staging_db_restore",
external_dag_id=STAGING_DB_RESTORE_DAG_ID,
# Set timeout to 0 to prevent retries. If the staging DB restoration is running,
# immediately fail the staging index creation DAG.
timeout=0,
# Wait for the whole DAG, not just a part of it
external_task_id=None,
check_existence=False,
execution_date_fn=lambda _: get_most_recent_dag_run(STAGING_DB_RESTORE_DAG_ID),
# Any "finished" state is sufficient for us to continue.
allowed_states=[State.SUCCESS, State.FAILED],
mode="reschedule",
)
try:
wait_for_dag.execute(context)
except AirflowSensorTimeout:
raise ValueError(
"Concurrency check failed. Staging index creation cannot start"
" during staging DB restoration."
)


@task
def get_target_alias(media_type: str, target_alias_override: str):
return target_alias_override or f"{media_type}-full"


@task.branch
def should_delete_index(should_delete: bool, old_index: str):
if should_delete and old_index:
# We should try to delete the old index only if the param is enabled,
# and we were able to find an index with the target_alias in the
# preceding task.
return "trigger_delete_index"
# Skip straight to notifying Slack.
return "notify_complete"


@task_group(group_id="create_index")
def create_index(media_type: str, index_suffix: str) -> None:
"""Create the new elasticsearch index on the staging cluster."""

# Get the DataRefresh config associated with this media type, in order to get
# the reindexing timeout information.
config = DATA_REFRESH_CONFIGS.get(media_type)
data_refresh_timeout = config.data_refresh_timeout if config else timedelta(days=1)

ingestion_server.trigger_and_wait_for_task(
action="REINDEX",
model=media_type,
data={"index_suffix": index_suffix},
timeout=data_refresh_timeout,
http_conn_id="staging_data_refresh",
)


@task_group(group_id="point_alias")
def point_alias(media_type: str, target_alias: str, index_suffix: str) -> None:
"""
Alias the index with the given suffix to the target_alias, first removing the
target_alias from any other indices to which it is linked.
"""
point_alias_payload = {
"alias": target_alias,
"index_suffix": index_suffix,
}

ingestion_server.trigger_and_wait_for_task(
action="POINT_ALIAS",
model=media_type,
data=point_alias_payload,
timeout=timedelta(hours=12), # matches the ingestion server's wait time
http_conn_id="staging_data_refresh",
)
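
This new module defines tasks and task groups but no `@dag`; the file that assembles them into the `recreate_full_staging_index` DAG is among the changed files not loaded in this view. Purely as a hedged sketch, one way such a DAG might wire these pieces together, assuming the module path implied by the import in the staging database restore DAG above and assuming the param names; the merged DAG will differ in its details:

```python
import pendulum
from airflow.decorators import dag
from airflow.models.baseoperator import chain
from airflow.models.param import Param
from es.recreate_staging_index.recreate_full_staging_index import (
    DAG_ID,
    create_index,
    get_target_alias,
    point_alias,
    prevent_concurrency_with_staging_database_restore,
)

from common import ingestion_server


@dag(
    dag_id=DAG_ID,
    schedule=None,
    start_date=pendulum.datetime(2023, 12, 1, tz="UTC"),
    catchup=False,
    params={
        # Param names are assumptions based on the task signatures above.
        "media_type": Param(default="image", enum=["image", "audio"]),
        "target_alias_override": Param(default="", type="string"),
    },
)
def recreate_full_staging_index():
    prevent_concurrency = prevent_concurrency_with_staging_database_restore()

    target_alias = get_target_alias(
        media_type="{{ params.media_type }}",
        target_alias_override="{{ params.target_alias_override }}",
    )
    index_suffix = ingestion_server.generate_index_suffix()

    # Passing the jinja string here means the DATA_REFRESH_CONFIGS lookup inside
    # create_index falls back to its default timeout; the real DAG presumably
    # resolves the media type differently.
    new_index = create_index(
        media_type="{{ params.media_type }}", index_suffix=index_suffix
    )
    staging_alias = point_alias(
        media_type="{{ params.media_type }}",
        target_alias=target_alias,
        index_suffix=index_suffix,
    )

    # The old-index lookup, the delete-old-index branch, and the final Slack
    # notification are omitted from this sketch.
    chain(prevent_concurrency, index_suffix, new_index, staging_alias)


recreate_full_staging_index()
```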
