From bed1affbb66345ada0de9435828a6519ff76716e Mon Sep 17 00:00:00 2001 From: Karen Braganza Date: Wed, 27 Nov 2024 05:55:37 -0500 Subject: [PATCH] Check pool_slots on partial task import instead of execution (#39724) (#42693) Co-authored-by: Ryan Hatter <25823361+RNHTTR@users.noreply.github.com> Co-authored-by: Utkarsh Sharma --- airflow/decorators/base.py | 6 ++++++ airflow/models/baseoperator.py | 5 +++++ tests/models/test_mappedoperator.py | 9 +++++++++ 3 files changed, 20 insertions(+) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index d743acbe50b2b..bcb64aaa6eb3c 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -457,6 +457,12 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None)) if partial_kwargs.get("pool") is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME + if "pool_slots" in partial_kwargs: + if partial_kwargs["pool_slots"] < 1: + dag_str = "" + if dag: + dag_str = f" in dag {dag.dag_id}" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES)) partial_kwargs["retry_delay"] = coerce_timedelta( partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY), diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 11522060fe06a..773552184f103 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -365,6 +365,11 @@ def partial( partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"]) if partial_kwargs["pool"] is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME + if partial_kwargs["pool_slots"] < 1: + dag_str = "" + if dag: + dag_str = f" in dag {dag.dag_id}" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"]) partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") if partial_kwargs["max_retry_delay"] is not None: diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 01991c0bb457d..cf547912fb924 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -221,6 +221,15 @@ def test_partial_on_class_invalid_ctor_args() -> None: MockOperator.partial(task_id="a", foo="bar", bar=2) +def test_partial_on_invalid_pool_slots_raises() -> None: + """Test that when we pass an invalid value to pool_slots in partial(), + + i.e. if the value is not an integer, an error is raised at import time.""" + + with pytest.raises(TypeError, match="'<' not supported between instances of 'str' and 'int'"): + MockOperator.partial(task_id="pool_slots_test", pool="test", pool_slots="a").expand(arg1=[1, 2, 3]) + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @pytest.mark.parametrize( ["num_existing_tis", "expected"],