Skip to content

Commit

Permalink
Disallow calling expand with no arguments (#23463)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored May 27, 2022
1 parent c9b21b8 commit 9214018
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 5 deletions.
3 changes: 3 additions & 0 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]):
raise TypeError(f"{func}() got unexpected keyword arguments {names}")

def expand(self, **map_kwargs: "Mappable") -> XComArg:
if not map_kwargs:
raise TypeError("no arguments to expand against")

self._validate_arg_names("expand", map_kwargs)
prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial")
ensure_xcomarg_return_value(map_kwargs)
Expand Down
5 changes: 5 additions & 0 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@ def __del__(self):
warnings.warn(f"Task {task_id} was never mapped!")

def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
if not mapped_kwargs:
raise TypeError("no arguments to expand against")
return self._expand(**mapped_kwargs)

def _expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
self._expand_called = True

from airflow.operators.empty import EmptyOperator
Expand Down
3 changes: 2 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def _get_default_mapped_partial() -> Dict[str, Any]:
are defaults, they are automatically supplied on de-serialization, so we
don't need to store them.
"""
default_partial_kwargs = BaseOperator.partial(task_id="_").expand().partial_kwargs
# Use the private _expand() method to avoid the empty kwargs check.
default_partial_kwargs = BaseOperator.partial(task_id="_")._expand().partial_kwargs
return BaseSerialization._serialize(default_partial_kwargs)[Encoding.VAR]


Expand Down
6 changes: 4 additions & 2 deletions tests/api_connexion/endpoints/test_task_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ def setup_dag(self, configured_app):
task2 = EmptyOperator(task_id=self.task_id2, start_date=self.task2_start_date)

with DAG(self.mapped_dag_id, start_date=self.task1_start_date) as mapped_dag:
task3 = EmptyOperator(task_id=self.task_id3) # noqa
mapped_task = EmptyOperator.partial(task_id=self.mapped_task_id).expand() # noqa
EmptyOperator(task_id=self.task_id3)
# Use the private _expand() method to avoid the empty kwargs check.
# We don't care about how the operator runs here, only its presence.
EmptyOperator.partial(task_id=self.mapped_task_id)._expand()

task1 >> task2
dag_bag = DagBag(os.devnull, include_examples=False)
Expand Down
8 changes: 6 additions & 2 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,7 +1053,9 @@ def test_xcom_pull(self, dag_maker):

def test_xcom_pull_mapped(self, dag_maker, session):
with dag_maker(dag_id="test_xcom", session=session):
task_1 = EmptyOperator.partial(task_id="task_1").expand()
# Use the private _expand() method to avoid the empty kwargs check.
# We don't care about how the operator runs here, only its presence.
task_1 = EmptyOperator.partial(task_id="task_1")._expand()
EmptyOperator(task_id="task_2")

dagrun = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6, 1, 0, 0, 0))
Expand Down Expand Up @@ -2763,7 +2765,9 @@ def cmds():
def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_value, dag_maker, session):
"""Ensure we access XCom lazily when pulling from a mapped operator."""
with dag_maker(dag_id="test_xcom", session=session):
task_1 = EmptyOperator.partial(task_id="task_1").expand()
# Use the private _expand() method to avoid the empty kwargs check.
# We don't care about how the operator runs here, only its presence.
task_1 = EmptyOperator.partial(task_id="task_1")._expand()
EmptyOperator(task_id="task_2")

dagrun = dag_maker.create_dagrun()
Expand Down

0 comments on commit 9214018

Please sign in to comment.