From fbe8919bdcd245b1add811217c4d283b3fe49c8b Mon Sep 17 00:00:00 2001 From: Zhenglin Li <1125806272@qq.com> Date: Fri, 16 Feb 2024 11:39:00 -0600 Subject: [PATCH 1/3] feat: Fall back to datasets' "extra" if arg extra is empty --- airflow/datasets/manager.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index 047f8494a481d..75d08121b50a2 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -64,6 +64,9 @@ def register_dataset_change( For local datasets, look them up, record the dataset event, queue dagruns, and broadcast the dataset event """ + if extra is None and dataset.extra: + extra = dataset.extra + dataset_model = session.scalar( select(DatasetModel) .where(DatasetModel.uri == dataset.uri) From b5fa4c95aae42709a1813eabf97763e12edbe1bc Mon Sep 17 00:00:00 2001 From: Zhenglin Li <1125806272@qq.com> Date: Fri, 16 Feb 2024 16:39:37 -0600 Subject: [PATCH 2/3] feat: add test cases to test extra --- tests/datasets/test_manager.py | 40 +++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py index 514ed8877ac75..26d907ae3c1f8 100644 --- a/tests/datasets/test_manager.py +++ b/tests/datasets/test_manager.py @@ -66,7 +66,7 @@ def test_register_dataset_change_dataset_doesnt_exist(self, mock_task_instance): mock_session.add.assert_not_called() mock_session.merge.assert_not_called() - def test_register_dataset_change(self, session, dag_maker, mock_task_instance): + def register_dataset_change_without_extra(self, session, dag_maker, mock_task_instance): dsem = DatasetManager() ds = Dataset(uri="test_dataset_uri") @@ -85,6 +85,44 @@ def test_register_dataset_change(self, session, dag_maker, mock_task_instance): assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1 assert session.query(DatasetDagRunQueue).count() == 2 + def test_register_dataset_change_with_dataset_extra(self, session, dag_maker, mock_task_instance): + dsem = DatasetManager() + + ds = Dataset(uri="test_dataset_uri", extra={"hi": "bye"}) + dag1 = DagModel(dag_id="dag1") + dag2 = DagModel(dag_id="dag2") + session.add_all([dag1, dag2]) + + dsm = DatasetModel(uri="test_dataset_uri") + session.add(dsm) + dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag.dag_id) for dag in (dag1, dag2)] + session.flush() + + dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session) + + # Ensure we've created a dataset + assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1 + assert session.query(DatasetDagRunQueue).count() == 2 + + def test_register_dataset_change_with_extra(self, session, dag_maker, mock_task_instance): + dsem = DatasetManager() + + ds = Dataset(uri="test_dataset_uri") + dag1 = DagModel(dag_id="dag1") + dag2 = DagModel(dag_id="dag2") + session.add_all([dag1, dag2]) + + dsm = DatasetModel(uri="test_dataset_uri") + session.add(dsm) + dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag.dag_id) for dag in (dag1, dag2)] + session.flush() + + dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session, extra={"hi": "bye"}) + + # Ensure we've created a dataset + assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1 + assert session.query(DatasetDagRunQueue).count() == 2 + def test_register_dataset_change_no_downstreams(self, session, mock_task_instance): dsem = DatasetManager() From 8bdd3795c6d4ae6baa96792a221bc7c17e03ed37 Mon Sep 17 00:00:00 2001 From: Zhenglin Li <1125806272@qq.com> Date: Fri, 16 Feb 2024 17:08:01 -0600 Subject: [PATCH 3/3] feat: add test cases to test extra --- tests/datasets/test_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py index 26d907ae3c1f8..7ca90c9c8dee5 100644 --- a/tests/datasets/test_manager.py +++ b/tests/datasets/test_manager.py @@ -117,7 +117,9 @@ def test_register_dataset_change_with_extra(self, session, dag_maker, mock_task_ dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag.dag_id) for dag in (dag1, dag2)] session.flush() - dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session, extra={"hi": "bye"}) + dsem.register_dataset_change( + task_instance=mock_task_instance, dataset=ds, session=session, extra={"hi": "bye"} + ) # Ensure we've created a dataset assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1