Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow calling expand with no arguments #23463

Merged
merged 1 commit into from
May 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]):
raise TypeError(f"{func}() got unexpected keyword arguments {names}")

def expand(self, **map_kwargs: "Mappable") -> XComArg:
if not map_kwargs:
raise TypeError("no arguments to expand against")

self._validate_arg_names("expand", map_kwargs)
prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial")
ensure_xcomarg_return_value(map_kwargs)
Expand Down
5 changes: 5 additions & 0 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@ def __del__(self):
warnings.warn(f"Task {task_id} was never mapped!")

def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
if not mapped_kwargs:
raise TypeError("no arguments to expand against")
return self._expand(**mapped_kwargs)

def _expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
self._expand_called = True

from airflow.operators.empty import EmptyOperator
Expand Down
3 changes: 2 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def _get_default_mapped_partial() -> Dict[str, Any]:
are defaults, they are automatically supplied on de-serialization, so we
don't need to store them.
"""
default_partial_kwargs = BaseOperator.partial(task_id="_").expand().partial_kwargs
# Use the private _expand() method to avoid the empty kwargs check.
default_partial_kwargs = BaseOperator.partial(task_id="_")._expand().partial_kwargs
return BaseSerialization._serialize(default_partial_kwargs)[Encoding.VAR]


Expand Down
6 changes: 4 additions & 2 deletions tests/api_connexion/endpoints/test_task_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ def setup_dag(self, configured_app):
task2 = EmptyOperator(task_id=self.task_id2, start_date=self.task2_start_date)

with DAG(self.mapped_dag_id, start_date=self.task1_start_date) as mapped_dag:
task3 = EmptyOperator(task_id=self.task_id3) # noqa
mapped_task = EmptyOperator.partial(task_id=self.mapped_task_id).expand() # noqa
EmptyOperator(task_id=self.task_id3)
# Use the private _expand() method to avoid the empty kwargs check.
# We don't care about how the operator runs here, only its presence.
EmptyOperator.partial(task_id=self.mapped_task_id)._expand()

task1 >> task2
dag_bag = DagBag(os.devnull, include_examples=False)
Expand Down
8 changes: 6 additions & 2 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,7 +1053,9 @@ def test_xcom_pull(self, dag_maker):

def test_xcom_pull_mapped(self, dag_maker, session):
with dag_maker(dag_id="test_xcom", session=session):
task_1 = EmptyOperator.partial(task_id="task_1").expand()
# Use the private _expand() method to avoid the empty kwargs check.
# We don't care about how the operator runs here, only its presence.
task_1 = EmptyOperator.partial(task_id="task_1")._expand()
EmptyOperator(task_id="task_2")

dagrun = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6, 1, 0, 0, 0))
Expand Down Expand Up @@ -2763,7 +2765,9 @@ def cmds():
def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_value, dag_maker, session):
"""Ensure we access XCom lazily when pulling from a mapped operator."""
with dag_maker(dag_id="test_xcom", session=session):
task_1 = EmptyOperator.partial(task_id="task_1").expand()
# Use the private _expand() method to avoid the empty kwargs check.
# We don't care about how the operator runs here, only its presence.
task_1 = EmptyOperator.partial(task_id="task_1")._expand()
EmptyOperator(task_id="task_2")

dagrun = dag_maker.create_dagrun()
Expand Down