From 72ab491703a55e41068aefdacfa9b6eb1fd17338 Mon Sep 17 00:00:00 2001 From: karen <158095947+karenbraganz@users.noreply.github.com> Date: Mon, 20 May 2024 20:10:28 -0400 Subject: [PATCH 1/7] Check pool_slots in mapped tasks using partial() --- airflow/models/baseoperator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 77423bfc3b99e..0f1038b7c45b1 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -362,6 +362,9 @@ def partial( partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"]) if partial_kwargs["pool"] is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME + if partial_kwargs["pool_slots"] < 1: + dag_str = f" in dag {partial_kwargs['dag'].dag_id}" if partial_kwargs['dag'] else "" + raise ValueError(f"pool slots for {partial_kwargs['task_id']}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"]) partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") if partial_kwargs["max_retry_delay"] is not None: From 64141f8515ef2922a206609fe854ca11cbe0ab18 Mon Sep 17 00:00:00 2001 From: karen <158095947+karenbraganz@users.noreply.github.com> Date: Wed, 22 May 2024 19:08:59 -0400 Subject: [PATCH 2/7] Check pool_slots in partial mapped tasks --- airflow/models/baseoperator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 0f1038b7c45b1..396e6e4b2b7b6 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -363,8 +363,8 @@ def partial( if partial_kwargs["pool"] is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME if partial_kwargs["pool_slots"] < 1: - dag_str = f" in dag {partial_kwargs['dag'].dag_id}" if partial_kwargs['dag'] else "" - raise ValueError(f"pool slots for {partial_kwargs['task_id']}{dag_str} cannot be less than 1") + dag_str = f" in dag {partial_kwargs['dag'].dag_id}" if partial_kwargs["dag"] else "" + raise ValueError(f"pool slots for {partial_kwargs['task_id']}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"]) partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") if partial_kwargs["max_retry_delay"] is not None: From 3929f6acc4af72dd08f3187ea2542b75eaf10283 Mon Sep 17 00:00:00 2001 From: karen <158095947+karenbraganz@users.noreply.github.com> Date: Wed, 29 May 2024 10:58:01 -0400 Subject: [PATCH 3/7] Add unit test --- tests/models/test_mappedoperator.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 9f31652424aeb..eb701688d5522 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -218,6 +218,15 @@ def test_partial_on_class_invalid_ctor_args() -> None: MockOperator.partial(task_id="a", foo="bar", bar=2) +def test_partial_on_invalid_pool_slots_raises() -> None: + """Test that when we pass an invalid value to pool_slots in partial(), + + i.e. if the value is not an integer, an error is raised at import time.""" + + with pytest.raises(TypeError, match="'<' not supported between instances of 'str' and 'int'"): + MockOperator.partial(task_id="pool_slots_test", pool="test", pool_slots="a").expand(arg1=[1, 2, 3]) + + @pytest.mark.parametrize( ["num_existing_tis", "expected"], ( From befe10f4904a8cd35047cc69d7fcc8dfb1de32e5 Mon Sep 17 00:00:00 2001 From: karen <158095947+karenbraganz@users.noreply.github.com> Date: Wed, 29 May 2024 11:00:08 -0400 Subject: [PATCH 4/7] Modify _expand() to match partial() --- airflow/decorators/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index d743acbe50b2b..15ed0d8cfc4b3 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -457,6 +457,9 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None)) if partial_kwargs.get("pool") is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME + if partial_kwargs.get("pool_slots") and partial_kwargs["pool_slots"] < 1: + dag_str = f" in dag {partial_kwargs['dag'].dag_id}" if partial_kwargs.get("dag") else "" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES)) partial_kwargs["retry_delay"] = coerce_timedelta( partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY), From 10ee46da62aeaad24e1a26cf64c7156133038717 Mon Sep 17 00:00:00 2001 From: karen <158095947+karenbraganz@users.noreply.github.com> Date: Thu, 20 Jun 2024 12:51:18 -0400 Subject: [PATCH 5/7] Split lines for readability --- airflow/decorators/base.py | 7 ++++++- airflow/models/baseoperator.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 15ed0d8cfc4b3..55cd3fc9931af 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -458,7 +458,12 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: if partial_kwargs.get("pool") is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME if partial_kwargs.get("pool_slots") and partial_kwargs["pool_slots"] < 1: - dag_str = f" in dag {partial_kwargs['dag'].dag_id}" if partial_kwargs.get("dag") else "" + dag_str = ( + f""" in dag {partial_kwargs['dag'].dag_id + }""" + if partial_kwargs.get("dag") + else "" + ) raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES)) partial_kwargs["retry_delay"] = coerce_timedelta( diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 396e6e4b2b7b6..c334d7dd14385 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -363,7 +363,12 @@ def partial( if partial_kwargs["pool"] is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME if partial_kwargs["pool_slots"] < 1: - dag_str = f" in dag {partial_kwargs['dag'].dag_id}" if partial_kwargs["dag"] else "" + dag_str = ( + f""" in dag {partial_kwargs['dag'].dag_id + }""" + if partial_kwargs["dag"] + else "" + ) raise ValueError(f"pool slots for {partial_kwargs['task_id']}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"]) partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") From 50c741519da1aef4a681e1beb302d9cb2f2ff766 Mon Sep 17 00:00:00 2001 From: karen <158095947+karenbraganz@users.noreply.github.com> Date: Thu, 20 Jun 2024 13:04:06 -0400 Subject: [PATCH 6/7] Use if block for better readability --- airflow/decorators/base.py | 9 +++------ airflow/models/baseoperator.py | 9 +++------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 55cd3fc9931af..38ba87d8b5d29 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -458,12 +458,9 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: if partial_kwargs.get("pool") is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME if partial_kwargs.get("pool_slots") and partial_kwargs["pool_slots"] < 1: - dag_str = ( - f""" in dag {partial_kwargs['dag'].dag_id - }""" - if partial_kwargs.get("dag") - else "" - ) + dag_str = "" + if partial_kwargs.get("dag"): + dag_str = f" in dag {partial_kwargs['dag'].dag_id}" raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES)) partial_kwargs["retry_delay"] = coerce_timedelta( diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index c334d7dd14385..ba8f61f0c8ba6 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -363,12 +363,9 @@ def partial( if partial_kwargs["pool"] is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME if partial_kwargs["pool_slots"] < 1: - dag_str = ( - f""" in dag {partial_kwargs['dag'].dag_id - }""" - if partial_kwargs["dag"] - else "" - ) + dag_str = "" + if partial_kwargs["dag"]: + dag_str = f" in dag {partial_kwargs['dag'].dag_id}" raise ValueError(f"pool slots for {partial_kwargs['task_id']}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"]) partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") From a5a76fce0adbe353038c2b71a6536c3dd308e011 Mon Sep 17 00:00:00 2001 From: Karen Braganza <158095947+karenbraganz@users.noreply.github.com> Date: Mon, 19 Aug 2024 19:30:01 -0400 Subject: [PATCH 7/7] Modify syntax and variables used --- airflow/decorators/base.py | 11 ++++++----- airflow/models/baseoperator.py | 17 ++++++++--------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 38ba87d8b5d29..bcb64aaa6eb3c 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -457,11 +457,12 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None)) if partial_kwargs.get("pool") is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME - if partial_kwargs.get("pool_slots") and partial_kwargs["pool_slots"] < 1: - dag_str = "" - if partial_kwargs.get("dag"): - dag_str = f" in dag {partial_kwargs['dag'].dag_id}" - raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") + if "pool_slots" in partial_kwargs: + if partial_kwargs["pool_slots"] < 1: + dag_str = "" + if dag: + dag_str = f" in dag {dag.dag_id}" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES)) partial_kwargs["retry_delay"] = coerce_timedelta( partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY), diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index ba8f61f0c8ba6..f38b45200f257 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -364,9 +364,9 @@ def partial( partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME if partial_kwargs["pool_slots"] < 1: dag_str = "" - if partial_kwargs["dag"]: - dag_str = f" in dag {partial_kwargs['dag'].dag_id}" - raise ValueError(f"pool slots for {partial_kwargs['task_id']}{dag_str} cannot be less than 1") + if dag: + dag_str = f" in dag {dag.dag_id}" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"]) partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") if partial_kwargs["max_retry_delay"] is not None: @@ -522,7 +522,11 @@ def __new__(cls, name, bases, namespace, **kwargs): partial_desc = vars(new_cls)["partial"] if isinstance(partial_desc, _PartialDescriptor): partial_desc.class_method = classmethod(partial) - new_cls.__init__ = cls._apply_defaults(new_cls.__init__) + + # We patch `__init__` only if the class defines it. + if inspect.getmro(new_cls)[1].__init__ is not new_cls.__init__: + new_cls.__init__ = cls._apply_defaults(new_cls.__init__) + return new_cls @@ -855,10 +859,6 @@ def say_hello_world(**context): _dag: DAG | None = None task_group: TaskGroup | None = None - # subdag parameter is only set for SubDagOperator. - # Setting it to None by default as other Operators do not have that field - subdag: DAG | None = None - start_date: pendulum.DateTime | None = None end_date: pendulum.DateTime | None = None @@ -1725,7 +1725,6 @@ def get_serialized_fields(cls): "end_date", "_task_type", "_operator_name", - "subdag", "ui_color", "ui_fgcolor", "template_ext",