diff --git a/catalog/dags/common/sensors/single_run_external_dags_sensor.py b/catalog/dags/common/sensors/single_run_external_dags_sensor.py
index 81afd0a7114..a6279b6dce3 100644
--- a/catalog/dags/common/sensors/single_run_external_dags_sensor.py
+++ b/catalog/dags/common/sensors/single_run_external_dags_sensor.py
@@ -21,6 +21,8 @@ class SingleRunExternalDAGsSensor(BaseSensorOperator):
     :param external_dag_ids: A list of dag_ids that you want to wait for
     :param check_existence: Set to `True` to check if the external DAGs exist,
         and immediately cease waiting if not (default value: False).
+    :param allow_concurrent_runs: Used to force the Sensor to pass, even
+        if there are concurrent runs.
     """
 
     def __init__(
@@ -28,12 +30,16 @@ def __init__(
         *,
         external_dag_ids: Iterable[str],
         check_existence: bool = False,
+        allow_concurrent_runs: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.external_dag_ids = external_dag_ids
         self.check_existence = check_existence
-        self._has_checked_existence = False
+        self.allow_concurrent_runs = allow_concurrent_runs
+
+        # Used to ensure some checks are only evaluated on the first poke
+        self._has_checked_params = False
 
     @provide_session
     def poke(self, context, session=None):
@@ -42,8 +48,20 @@ def poke(self, context, session=None):
             self.external_dag_ids,
         )
 
-        if self.check_existence:
-            self._check_for_existence(session=session)
+        if not self._has_checked_params:
+            if self.allow_concurrent_runs:
+                self.log.info(
+                    "`allow_concurrent_runs` is enabled. Returning without"
+                    " checking for running DAGs."
+                )
+                return True
+
+            if self.check_existence:
+                self._check_for_existence(session=session)
+
+            # Only check DAG existence and `allow_concurrent_runs`
+            # on the first execution.
+            self._has_checked_params = True
 
         count_running = self.get_count(session)
 
@@ -51,10 +69,6 @@ def poke(self, context, session=None):
         return count_running == 0
 
     def _check_for_existence(self, session) -> None:
-        # Check DAG existence only once, on the first execution.
-        if self._has_checked_existence:
-            return
-
         for dag_id in self.external_dag_ids:
             dag_to_wait = (
                 session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
@@ -72,7 +86,6 @@ def _check_for_existence(self, session) -> None:
                     f"The external DAG {dag_id} does not have a task "
                     f"with id {self.task_id}."
                 )
-        self._has_checked_existence = True
 
     def get_count(self, session) -> int:
         # Get the count of running DAGs. A DAG is considered 'running' if
diff --git a/catalog/dags/data_refresh/dag_factory.py b/catalog/dags/data_refresh/dag_factory.py
index df3f9a9557b..cb66c92c406 100644
--- a/catalog/dags/data_refresh/dag_factory.py
+++ b/catalog/dags/data_refresh/dag_factory.py
@@ -70,12 +70,14 @@ def wait_for_conflicting_dags(
     data_refresh_config: DataRefreshConfig,
     external_dag_ids: list[str],
     concurrency_tag: str,
+    allow_concurrent_data_refreshes: bool,
 ):
     # Wait to ensure that no other Data Refresh DAGs are running.
     SingleRunExternalDAGsSensor(
         task_id="wait_for_data_refresh",
         external_dag_ids=external_dag_ids,
         check_existence=True,
+        allow_concurrent_runs=allow_concurrent_data_refreshes,
         poke_interval=data_refresh_config.concurrency_check_poke_interval,
         mode="reschedule",
         pool=DATA_REFRESH_POOL,
@@ -152,7 +154,17 @@ def create_data_refresh_dag(
                 "Optional suffix appended to the `media_type` in the Elasticsearch index"
                 " name. If not supplied, a uuid is used."
             ),
-        )
+        ),
+        "allow_concurrent_data_refreshes": Param(
+            default=False,
+            type="boolean",
+            description=(
+                "Whether to allow multiple data refresh DAGs for the given environment"
+                " to run concurrently. This setting should be enabled with extreme"
+                " caution, as reindexing multiple large Elasticsearch indices"
+                " simultaneously should be avoided."
+            ),
+        ),
     },
 )
 
@@ -169,7 +181,10 @@ def create_data_refresh_dag(
         )
 
         wait_for_dags = wait_for_conflicting_dags(
-            data_refresh_config, external_dag_ids, concurrency_tag
+            data_refresh_config,
+            external_dag_ids,
+            concurrency_tag,
+            "{{ params.allow_concurrent_data_refreshes }}",
         )
 
         copy_data = copy_upstream_tables(
diff --git a/catalog/tests/dags/common/sensors/test_single_run_external_dags_sensor.py b/catalog/tests/dags/common/sensors/test_single_run_external_dags_sensor.py
index 6b11acd9890..a544f92ad85 100644
--- a/catalog/tests/dags/common/sensors/test_single_run_external_dags_sensor.py
+++ b/catalog/tests/dags/common/sensors/test_single_run_external_dags_sensor.py
@@ -212,8 +212,25 @@ def test_succeeds_if_no_running_dags(
     "ignore:This class is deprecated. Please use "
     "`airflow.utils.task_group.TaskGroup`.:airflow.exceptions.RemovedInAirflow3Warning"
 )
+@pytest.mark.parametrize(
+    "allow_concurrent_runs, expected_message",
+    [
+        (False, "1 DAGs are in the running state"),
+        (
+            True,
+            "`allow_concurrent_runs` is enabled. Returning without checking for"
+            " running DAGs.",
+        ),
+    ],
+)
 def test_retries_if_running_dags_with_completed_sensor_task(
-    caplog, sample_dag_id_fixture, sample_pool_fixture, clean_db, setup_pool
+    allow_concurrent_runs,
+    expected_message,
+    caplog,
+    sample_dag_id_fixture,
+    sample_pool_fixture,
+    clean_db,
+    setup_pool,
 ):
     # Create a DAG in the 'running' state
     running_dag = create_dag("running_dag", sample_dag_id_fixture, sample_pool_fixture)
@@ -236,7 +253,7 @@ def test_retries_if_running_dags_with_completed_sensor_task(
 
     # Create the Test DAG and sensor and set up dependent dag Ids
     dag = DAG(
-        "test_dag_failure",
+        f"test_dag_failure_with_allow_concurrent_runs_{allow_concurrent_runs}",
         schedule=None,
         default_args={
             "owner": "airflow",
@@ -249,6 +266,7 @@ def test_retries_if_running_dags_with_completed_sensor_task(
             f"{sample_dag_id_fixture}_success_dag",
             f"{sample_dag_id_fixture}_running_dag",
         ],
+        allow_concurrent_runs=allow_concurrent_runs,
         poke_interval=5,
         mode="reschedule",
         dag=dag,
@@ -263,4 +281,4 @@ def test_retries_if_running_dags_with_completed_sensor_task(
         f"{sample_dag_id_fixture}_success_dag', '{sample_dag_id_fixture}_running_dag'] ..."
         in caplog.text
     )
-    assert "1 DAGs are in the running state" in caplog.text
+    assert expected_message in caplog.text
diff --git a/load_sample_data.sh b/load_sample_data.sh
index cc260e46a92..e6fe593b8da 100755
--- a/load_sample_data.sh
+++ b/load_sample_data.sh
@@ -161,8 +161,8 @@ just catalog/cli airflow dags unpause staging_audio_data_refresh
 just catalog/cli airflow dags unpause staging_image_data_refresh
 # Trigger the data refresh dags at the same time. The DAGs will manage
 # concurrency issues.
-just catalog/cli airflow dags trigger staging_audio_data_refresh --conf '{"index_suffix": "init"}'
-just catalog/cli airflow dags trigger staging_image_data_refresh --conf '{"index_suffix": "init"}'
+just catalog/cli airflow dags trigger staging_audio_data_refresh --conf '{"index_suffix": "init", "allow_concurrent_data_refreshes": true}'
+just catalog/cli airflow dags trigger staging_image_data_refresh --conf '{"index_suffix": "init", "allow_concurrent_data_refreshes": true}'
 # Wait for all relevant indices to be created and promoted
 just docker/es/wait-for-index "audio"
 just docker/es/wait-for-count "audio"
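
For reviewers who want to exercise the new flag outside the data refresh factory, here is a minimal sketch of wiring `SingleRunExternalDAGsSensor` with `allow_concurrent_runs` into a standalone DAG. The `example_*` identifiers are hypothetical (they do not exist in the repository), and the sketch passes a literal boolean rather than the templated `{{ params.allow_concurrent_data_refreshes }}` string used in `dag_factory.py`:

```python
from datetime import datetime

from airflow.models import DAG

# Assumes the catalog's `dags` folder is on the Airflow module path, matching
# catalog/dags/common/sensors/single_run_external_dags_sensor.py above.
from common.sensors.single_run_external_dags_sensor import SingleRunExternalDAGsSensor

with DAG(
    dag_id="example_concurrency_gated_dag",  # hypothetical DAG id
    schedule=None,
    start_date=datetime(2024, 1, 1),
):
    SingleRunExternalDAGsSensor(
        task_id="wait_for_external_dags",
        external_dag_ids=["example_other_dag"],  # hypothetical external DAG
        check_existence=True,
        # With allow_concurrent_runs=True, the first poke logs and returns
        # True immediately, skipping the running-DAG count entirely.
        allow_concurrent_runs=True,
        poke_interval=60,
        mode="reschedule",
    )
```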