Skip to content

Commit

Permalink
Check pool_slots on partial task import instead of execution (#39724) (
Browse files Browse the repository at this point in the history
…#42693)

Co-authored-by: Ryan Hatter <[email protected]>
Co-authored-by: Utkarsh Sharma <[email protected]>
  • Loading branch information
3 people authored Nov 27, 2024
1 parent a361291 commit bed1aff
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
6 changes: 6 additions & 0 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,12 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None))
if partial_kwargs.get("pool") is None:
partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
if "pool_slots" in partial_kwargs:
if partial_kwargs["pool_slots"] < 1:
dag_str = ""
if dag:
dag_str = f" in dag {dag.dag_id}"
raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")
partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES))
partial_kwargs["retry_delay"] = coerce_timedelta(
partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
Expand Down
5 changes: 5 additions & 0 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,11 @@ def partial(
partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"])
if partial_kwargs["pool"] is None:
partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
if partial_kwargs["pool_slots"] < 1:
dag_str = ""
if dag:
dag_str = f" in dag {dag.dag_id}"
raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")
partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"])
partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay")
if partial_kwargs["max_retry_delay"] is not None:
Expand Down
9 changes: 9 additions & 0 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,15 @@ def test_partial_on_class_invalid_ctor_args() -> None:
MockOperator.partial(task_id="a", foo="bar", bar=2)


def test_partial_on_invalid_pool_slots_raises() -> None:
"""Test that when we pass an invalid value to pool_slots in partial(),
i.e. if the value is not an integer, an error is raised at import time."""

with pytest.raises(TypeError, match="'<' not supported between instances of 'str' and 'int'"):
MockOperator.partial(task_id="pool_slots_test", pool="test", pool_slots="a").expand(arg1=[1, 2, 3])


@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
@pytest.mark.parametrize(
["num_existing_tis", "expected"],
Expand Down

0 comments on commit bed1aff

Please sign in to comment.