From 9214018153dd193be6b1147629f73b23d8195cce Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 27 May 2022 00:25:13 -0400 Subject: [PATCH] Disallow calling expand with no arguments (#23463) --- airflow/decorators/base.py | 3 +++ airflow/models/mappedoperator.py | 5 +++++ airflow/serialization/serialized_objects.py | 3 ++- tests/api_connexion/endpoints/test_task_endpoint.py | 6 ++++-- tests/models/test_taskinstance.py | 8 ++++++-- 5 files changed, 20 insertions(+), 5 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 79277c7281e11..1b14cd066832e 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -312,6 +312,9 @@ def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]): raise TypeError(f"{func}() got unexpected keyword arguments {names}") def expand(self, **map_kwargs: "Mappable") -> XComArg: + if not map_kwargs: + raise TypeError("no arguments to expand against") + self._validate_arg_names("expand", map_kwargs) prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial") ensure_xcomarg_return_value(map_kwargs) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index c522cefb2c956..663ceeece16a0 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -191,6 +191,11 @@ def __del__(self): warnings.warn(f"Task {task_id} was never mapped!") def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator": + if not mapped_kwargs: + raise TypeError("no arguments to expand against") + return self._expand(**mapped_kwargs) + + def _expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator": self._expand_called = True from airflow.operators.empty import EmptyOperator diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 8d21bca8ee915..3e674b2f8d0a2 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -96,7 +96,8 @@ def _get_default_mapped_partial() -> Dict[str, Any]: are defaults, they are automatically supplied on de-serialization, so we don't need to store them. """ - default_partial_kwargs = BaseOperator.partial(task_id="_").expand().partial_kwargs + # Use the private _expand() method to avoid the empty kwargs check. + default_partial_kwargs = BaseOperator.partial(task_id="_")._expand().partial_kwargs return BaseSerialization._serialize(default_partial_kwargs)[Encoding.VAR] diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index 9748305d8c0af..7509a89032b95 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -67,8 +67,10 @@ def setup_dag(self, configured_app): task2 = EmptyOperator(task_id=self.task_id2, start_date=self.task2_start_date) with DAG(self.mapped_dag_id, start_date=self.task1_start_date) as mapped_dag: - task3 = EmptyOperator(task_id=self.task_id3) # noqa - mapped_task = EmptyOperator.partial(task_id=self.mapped_task_id).expand() # noqa + EmptyOperator(task_id=self.task_id3) + # Use the private _expand() method to avoid the empty kwargs check. + # We don't care about how the operator runs here, only its presence. + EmptyOperator.partial(task_id=self.mapped_task_id)._expand() task1 >> task2 dag_bag = DagBag(os.devnull, include_examples=False) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 975d44cf922de..a1d180fa1ee56 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1053,7 +1053,9 @@ def test_xcom_pull(self, dag_maker): def test_xcom_pull_mapped(self, dag_maker, session): with dag_maker(dag_id="test_xcom", session=session): - task_1 = EmptyOperator.partial(task_id="task_1").expand() + # Use the private _expand() method to avoid the empty kwargs check. + # We don't care about how the operator runs here, only its presence. + task_1 = EmptyOperator.partial(task_id="task_1")._expand() EmptyOperator(task_id="task_2") dagrun = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6, 1, 0, 0, 0)) @@ -2763,7 +2765,9 @@ def cmds(): def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_value, dag_maker, session): """Ensure we access XCom lazily when pulling from a mapped operator.""" with dag_maker(dag_id="test_xcom", session=session): - task_1 = EmptyOperator.partial(task_id="task_1").expand() + # Use the private _expand() method to avoid the empty kwargs check. + # We don't care about how the operator runs here, only its presence. + task_1 = EmptyOperator.partial(task_id="task_1")._expand() EmptyOperator(task_id="task_2") dagrun = dag_maker.create_dagrun()