Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add distributed reindex steps #4572

Merged
merged 20 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions catalog/dags/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
AWS_RDS_CONN_ID = os.environ.get("AWS_RDS_CONN_ID", AWS_CONN_ID)
ES_PROD_HTTP_CONN_ID = "elasticsearch_http_production"
REFRESH_POKE_INTERVAL = int(os.getenv("DATA_REFRESH_POKE_INTERVAL", 60 * 30))
DATA_REFRESH_POOL = os.getenv("DATA_REFRESH_POOL", "data_refresh")


@dataclass
Expand Down
18 changes: 18 additions & 0 deletions catalog/dags/common/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,24 @@ def remove_excluded_index_settings(index_config):
return index_config


@task
def get_index_configuration_copy(
source_index: str, target_index_name: str, es_host: str
):
stacimc marked this conversation as resolved.
Show resolved Hide resolved
"""
Create a new index configuration based off the `source_index` but with the given
`target_index_name`, in the format needed for `create_index`. Removes fields that
should not be copied into a new index configuration such as the uuid.
"""
base_config = get_index_configuration.function(source_index, es_host)

cleaned_config = remove_excluded_index_settings(base_config)

cleaned_config["index"] = target_index_name

return cleaned_config


@task
def get_record_count_group_by_sources(es_host: str, index: str):
"""
Expand Down
11 changes: 11 additions & 0 deletions catalog/dags/common/operators/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from airflow.providers.http.operators.http import HttpOperator


class TemplatedConnectionHttpOperator(HttpOperator):
"""
Wrapper around the HTTPOperator which allows templating of the conn_id,
in order to support using a conn_id passed through XCOMs.
"""

# Extended to allow templating of conn_id
template_fields = HttpOperator.template_fields + ("http_conn_id",)
11 changes: 11 additions & 0 deletions catalog/dags/common/sensors/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from airflow.providers.http.sensors.http import HttpSensor


class TemplatedConnectionHttpSensor(HttpSensor):
"""
Wrapper around the HTTPSensor which allows templating of the conn_id,
in order to support using a conn_id passed through XCOMs.
"""

# Extended to allow templating of conn_id
template_fields = HttpSensor.template_fields + ("http_conn_id",)
6 changes: 3 additions & 3 deletions catalog/dags/data_refresh/alter_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,12 @@ def report(counts: list[int]):

@task_group(group_id="alter_table_data")
def alter_table_data(
environment: Environment,
target_environment: Environment,
data_refresh_config: DataRefreshConfig,
):
"""Perform data altering across a number of tasks."""
postgres_conn_id = POSTGRES_API_CONN_IDS.get(environment)
temp_table = data_refresh_config.table_mappings[0].temp_table_name
postgres_conn_id = POSTGRES_API_CONN_IDS.get(target_environment)
temp_table = data_refresh_config.table_mapping.temp_table_name

estimated_record_count = PGExecuteQueryOperator(
task_id="get_estimated_record_count",
Expand Down
9 changes: 9 additions & 0 deletions catalog/dags/data_refresh/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from common.constants import PRODUCTION, STAGING


INDEXER_WORKER_COUNTS = {STAGING: 2, PRODUCTION: 6}

INDEXER_LAUNCH_TEMPLATES = {
STAGING: "indexer-worker-pool-s",
PRODUCTION: "indexer-worker-pool-p",
}
7 changes: 3 additions & 4 deletions catalog/dags/data_refresh/copy_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ def copy_data(
def copy_upstream_table(
upstream_conn_id: str,
downstream_conn_id: str,
environment: Environment,
timeout: timedelta,
limit: int,
upstream_table_name: str,
Expand Down Expand Up @@ -275,12 +274,13 @@ def copy_upstream_table(
create_temp_table >> setup_id_columns >> setup_tertiary_columns
setup_tertiary_columns >> copy
copy >> add_primary_key

return


@task_group(group_id="copy_upstream_tables")
def copy_upstream_tables(
environment: Environment, data_refresh_config: DataRefreshConfig
target_environment: Environment, data_refresh_config: DataRefreshConfig
):
"""
For each upstream table associated with the given media type, create a new
Expand All @@ -290,7 +290,7 @@ def copy_upstream_tables(
This task does _not_ apply all indices and constraints, merely copies
the data.
"""
downstream_conn_id = POSTGRES_API_CONN_IDS.get(environment)
downstream_conn_id = POSTGRES_API_CONN_IDS.get(target_environment)
upstream_conn_id = POSTGRES_CONN_ID

create_fdw = _run_sql.override(task_id="create_fdw")(
Expand All @@ -309,7 +309,6 @@ def copy_upstream_tables(
copy_tables = copy_upstream_table.partial(
upstream_conn_id=upstream_conn_id,
downstream_conn_id=downstream_conn_id,
environment=environment,
timeout=data_refresh_config.copy_data_timeout,
limit=limit,
).expand_kwargs([asdict(tm) for tm in data_refresh_config.table_mappings])
Expand Down
46 changes: 46 additions & 0 deletions catalog/dags/data_refresh/create_and_promote_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
# Create and Promote Index

This file contains TaskGroups related to creating and promoting Elasticsearch indices
as part of the Data Refresh.
"""

import logging
import uuid

from airflow.decorators import task, task_group

from common import elasticsearch as es
from data_refresh.data_refresh_types import DataRefreshConfig


logger = logging.getLogger(__name__)


@task
def generate_index_name(media_type: str) -> str:
return f"{media_type}-{uuid.uuid4().hex}"


@task_group(group_id="create_temp_index")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yay for being able to do this with decorators!

def create_index(
data_refresh_config: DataRefreshConfig,
es_host: str,
):
# Generate a UUID suffix that will be used by the newly created index.
temp_index_name = generate_index_name(media_type=data_refresh_config.media_type)

# Get the configuration for the new Elasticsearch index, based off the existing index.
index_config = es.get_index_configuration_copy.override(
task_id="get_index_configuration"
)(
source_index=data_refresh_config.media_type,
target_index_name=temp_index_name,
es_host=es_host,
)

# Create a new index matching the existing configuration
es.create_index(index_config=index_config, es_host=es_host)

# Return the name of the created index
return temp_index_name
71 changes: 44 additions & 27 deletions catalog/dags/data_refresh/dag_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
"""

import logging
import os
from collections.abc import Sequence
from itertools import product

Expand All @@ -38,22 +37,26 @@

from common import cloudwatch
from common import elasticsearch as es
from common.constants import DAG_DEFAULT_ARGS, ENVIRONMENTS, Environment
from common.constants import (
DAG_DEFAULT_ARGS,
DATA_REFRESH_POOL,
ENVIRONMENTS,
Environment,
)
from common.sensors.constants import ES_CONCURRENCY_TAGS
from common.sensors.single_run_external_dags_sensor import SingleRunExternalDAGsSensor
from common.sensors.utils import wait_for_external_dags_with_tag
from data_refresh.alter_data import alter_table_data
from data_refresh.copy_data import copy_upstream_tables
from data_refresh.create_and_promote_index import create_index
from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS, DataRefreshConfig
from data_refresh.distributed_reindex import perform_distributed_reindex
from data_refresh.reporting import report_record_difference


logger = logging.getLogger(__name__)


DATA_REFRESH_POOL = os.getenv("DATA_REFRESH_POOL", "data_refresh")


@task_group(group_id="wait_for_conflicting_dags")
def wait_for_conflicting_dags(
data_refresh_config: DataRefreshConfig,
Expand All @@ -65,7 +68,7 @@ def wait_for_conflicting_dags(
task_id="wait_for_data_refresh",
external_dag_ids=external_dag_ids,
check_existence=True,
poke_interval=data_refresh_config.data_refresh_poke_interval,
poke_interval=data_refresh_config.concurrency_check_poke_interval,
mode="reschedule",
pool=DATA_REFRESH_POOL,
)
Expand All @@ -86,7 +89,7 @@ def wait_for_conflicting_dags(

def create_data_refresh_dag(
data_refresh_config: DataRefreshConfig,
environment: Environment,
target_environment: Environment,
external_dag_ids: Sequence[str],
):
"""
Expand All @@ -96,40 +99,40 @@ def create_data_refresh_dag(

Required Arguments:

data_refresh: dataclass containing configuration information for the
DAG
environment: the environment in which the data refresh is performed
external_dag_ids: list of ids of the other data refresh DAGs. The data refresh step
of this DAG will not run concurrently with the corresponding step
of any dependent DAG.
data_refresh: dataclass containing configuration information for the
DAG
target_environment: the API environment in which the data refresh is performed
external_dag_ids: list of ids of the other data refresh DAGs. The data refresh step
of this DAG will not run concurrently with the corresponding step
of any dependent DAG.
"""
default_args = {
**DAG_DEFAULT_ARGS,
**data_refresh_config.default_args,
}

concurrency_tag = ES_CONCURRENCY_TAGS.get(environment)
concurrency_tag = ES_CONCURRENCY_TAGS.get(target_environment)

dag = DAG(
dag_id=f"{environment}_{data_refresh_config.dag_id}",
dag_id=f"{target_environment}_{data_refresh_config.dag_id}",
dagrun_timeout=data_refresh_config.dag_timeout,
default_args=default_args,
start_date=data_refresh_config.start_date,
schedule=data_refresh_config.schedule,
render_template_as_native_obj=True,
max_active_runs=1,
catchup=False,
doc_md=__doc__,
tags=[
"data_refresh",
f"{environment}_data_refresh",
f"{target_environment}_data_refresh",
concurrency_tag,
],
render_template_as_native_obj=True,
)

with dag:
# Connect to the appropriate Elasticsearch cluster
es_host = es.get_es_host(environment=environment)
es_host = es.get_es_host(environment=target_environment)

# Get the current number of records in the target API table
before_record_count = es.get_record_count_group_by_sources.override(
Expand All @@ -144,11 +147,19 @@ def create_data_refresh_dag(
)

copy_data = copy_upstream_tables(
environment=environment, data_refresh_config=data_refresh_config
target_environment=target_environment,
data_refresh_config=data_refresh_config,
)

alter_data = alter_table_data(
environment=environment, data_refresh_config=data_refresh_config
target_environment=target_environment,
data_refresh_config=data_refresh_config,
)

# Create a new temporary index based off the configuration of the existing media index.
# This will later replace the live index.
target_index = create_index(
data_refresh_config=data_refresh_config, es_host=es_host
)

# Disable Cloudwatch alarms that are noisy during the reindexing steps of a
Expand All @@ -161,8 +172,13 @@ def create_data_refresh_dag(
},
)

# TODO create_and_populate_index
# (TaskGroup that creates index, triggers and waits for reindexing)
# Populate the Elasticsearch index.
reindex = perform_distributed_reindex(
environment="{{ var.value.ENVIRONMENT }}",
target_environment=target_environment,
target_index=target_index,
data_refresh_config=data_refresh_config,
)

# TODO create_and_populate_filtered_index

Expand Down Expand Up @@ -207,10 +223,11 @@ def create_data_refresh_dag(
>> wait_for_dags
>> copy_data
>> alter_data
>> target_index
>> disable_alarms
>> reindex
)
# TODO: this will include reindex/etc once added
disable_alarms >> [enable_alarms, after_record_count]
reindex >> [enable_alarms, after_record_count]
after_record_count >> report_counts

return dag
Expand All @@ -219,14 +236,14 @@ def create_data_refresh_dag(
# Generate data refresh DAGs for each DATA_REFRESH_CONFIG, per environment.
all_data_refresh_dag_ids = {refresh.dag_id for refresh in DATA_REFRESH_CONFIGS.values()}

for data_refresh_config, environment in product(
for data_refresh_config, target_environment in product(
DATA_REFRESH_CONFIGS.values(), ENVIRONMENTS
):
# Construct a set of all data refresh DAG ids other than the current DAG
other_dag_ids = all_data_refresh_dag_ids - {data_refresh_config.dag_id}

globals()[data_refresh_config.dag_id] = create_data_refresh_dag(
data_refresh_config,
environment,
[f"{environment}_{dag_id}" for dag_id in other_dag_ids],
target_environment,
[f"{target_environment}_{dag_id}" for dag_id in other_dag_ids],
)
Loading