Skip to content

Commit

Permalink
Only refresh index once, clean up es task params
Browse files Browse the repository at this point in the history
  • Loading branch information
stacimc committed Feb 28, 2024
1 parent 72a922b commit f4613ba
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 28 deletions.
48 changes: 25 additions & 23 deletions catalog/dags/common/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from airflow.decorators import task, task_group
from airflow.models.connection import Connection
from airflow.providers.elasticsearch.hooks.elasticsearch import ElasticsearchPythonHook
from airflow.sensors.python import PythonSensor
from airflow.sensors.base import PokeReturnValue
from airflow.utils.trigger_rule import TriggerRule

from common.constants import REFRESH_POKE_INTERVAL
Expand All @@ -16,7 +16,7 @@

# Index settings that should not be copied over from the base configuration when
# creating a new index.
EXCLUDED_INDEX_SETTINGS = ["provided_name", "creation_date", "uuid", "version"]
EXCLUDED_INDEX_SETTINGS = {"provided_name", "creation_date", "uuid", "version"}


@task
Expand Down Expand Up @@ -88,13 +88,14 @@ def create_index(index_config, es_host: str):

@task_group(group_id="trigger_and_wait_for_reindex")
def trigger_and_wait_for_reindex(
es_host: str,
destination_index: str,
source_index: str,
query: dict,
timeout: timedelta,
requests_per_second: int,
es_host: str,
query: dict | None = None,
max_docs: int | None = None,
refresh: bool = True,
slices: Union[int, Literal["auto"]] = "auto",
):
@task
Expand Down Expand Up @@ -122,15 +123,20 @@ def trigger_reindex(
slices=slices,
# Do not hold the slot while awaiting completion
wait_for_completion=False,
# Immediately refresh the index after completion to make
# Whether to immediately refresh the index after completion to make
# the data available for search
refresh=True,
refresh=refresh,
# Throttle
requests_per_second=requests_per_second,
)
return response["task"]

def _wait_for_reindex(task_id: str, expected_docs: int, es_host: str):
@task.sensor(
poke_interval=REFRESH_POKE_INTERVAL, timeout=timeout, mode="reschedule"
)
def wait_for_reindex(
es_host: str, task_id: str, expected_docs: int | None = None
) -> PokeReturnValue:
es_conn = ElasticsearchPythonHook(hosts=[es_host]).get_conn

response = es_conn.tasks.get(task_id=task_id)
Expand All @@ -142,7 +148,8 @@ def _wait_for_reindex(task_id: str, expected_docs: int, es_host: str):
)
else:
logger.info(f"Reindexed {count} documents.")
return response.get("completed")

return PokeReturnValue(is_done=response.get("completed") is True)

trigger_reindex_task = trigger_reindex(
es_host,
Expand All @@ -154,19 +161,17 @@ def _wait_for_reindex(task_id: str, expected_docs: int, es_host: str):
slices,
)

wait_for_reindex = PythonSensor(
task_id="wait_for_reindex",
python_callable=_wait_for_reindex,
timeout=timeout,
poke_interval=REFRESH_POKE_INTERVAL,
op_kwargs={
"task_id": trigger_reindex_task,
"expected_docs": max_docs,
"es_host": es_host,
},
wait_for_reindex_task = wait_for_reindex(
task_id=trigger_reindex_task, expected_docs=max_docs, es_host=es_host
)

trigger_reindex_task >> wait_for_reindex
trigger_reindex_task >> wait_for_reindex_task


@task
def refresh_index(es_host: str, index_name: str):
    """
    Refresh the given Elasticsearch index, making recently indexed
    documents visible to search.

    :param es_host: Host of the Elasticsearch cluster to connect to.
    :param index_name: Name of the index to refresh.
    :return: The raw response from the Elasticsearch refresh API call.
    """
    hook = ElasticsearchPythonHook(hosts=[es_host])
    client = hook.get_conn
    return client.indices.refresh(index=index_name)


@task_group(group_id="point_alias")
Expand Down Expand Up @@ -209,10 +214,7 @@ def point_new_alias(
return response.get("acknowledged")

exists_alias = check_if_alias_exists(alias, es_host)

remove_alias = remove_existing_alias.override(task_id="remove_existing_alias")(
alias, es_host
)
remove_alias = remove_existing_alias(alias, es_host)

point_alias = point_new_alias.override(
# The remove_alias task may be skipped.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from datetime import datetime, timedelta

from airflow.decorators import dag
from airflow.models import Variable
from airflow.models.param import Param

from common import elasticsearch as es
Expand Down Expand Up @@ -157,15 +156,19 @@ def create_proportional_by_source_staging_index():
destination_index=destination_index_name,
source_index=source_index_name,
timeout=timedelta(hours=12),
requests_per_second=Variable.get(
"ES_INDEX_THROTTLING_RATE", 20_000, deserialize_json=True
),
requests_per_second="{{ var.value.get('ES_INDEX_THROTTLING_RATE', 20_000) }}",
# When slices are used to parallelize indexing, max_docs does
# not work reliably and the final proportions may be incorrect.
slices=None,
# Do not refresh the index after each partial reindex
refresh=False,
es_host=es_host,
).expand_kwargs(desired_source_counts)

refresh_destination_index = es.refresh_index(
index_name=destination_index_name, es_host=es_host
)

point_alias = es.point_alias(
index_name=destination_index_name, alias=destination_alias, es_host=es_host
)
Expand All @@ -182,7 +185,7 @@ def create_proportional_by_source_staging_index():
es_host >> [source_index_name, destination_index_name, destination_alias]
staging_source_counts >> desired_source_counts
new_index >> staging_source_counts
reindex >> point_alias >> notify_completion
reindex >> refresh_destination_index >> point_alias >> notify_completion


create_proportional_by_source_staging_index()

0 comments on commit f4613ba

Please sign in to comment.