Skip to content

Commit

Permalink
Fix mapped tasks partial arguments when DAG default args are provided (
Browse files Browse the repository at this point in the history
…#29913)

* Add a failing test to make it pass

* use partial_kwargs when they are provided, and override only None values with the DAG default values

* update the test and check if the values are filled in the right order

* fix overriding retry_delay with default value when it is equal to 0

* add missing default value for inlets and outlets

* set partial_kwargs dict type to dict[str, Any] and remove type ignore comments

* create a dict for default values and use NOTSET instead of None, so that None is supported as an accepted value

* update partial typing by removing None type from some args and set NotSet for all args

* Tweak kwarg merging slightly

This should improve iteration a bit, I think.

* Fix unit tests

---------

Co-authored-by: Tzu-ping Chung <[email protected]>
(cherry picked from commit f01051a)
  • Loading branch information
hussein-awala authored and ephraimbuddy committed Apr 14, 2023
1 parent 92a8904 commit 4fac945
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 79 deletions.
187 changes: 108 additions & 79 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET, ArgNotSet
from airflow.utils.weight_rule import WeightRule
from airflow.utils.xcom import XCOM_RETURN_KEY

Expand Down Expand Up @@ -184,50 +185,70 @@ def partial(**kwargs):
return self.class_method.__get__(cls, cls)


# Fallback values for ``partial()`` arguments that are still NOTSET after the
# DAG-level default args have been merged in. ``partial()`` looks values up with
# ``_PARTIAL_DEFAULTS.get(k)``, so any argument without an entry here falls back
# to None.
# NOTE(review): the ``inlets``/``outlets`` lists below are single module-level
# objects returned by every lookup -- confirm no consumer mutates them in place.
_PARTIAL_DEFAULTS = {
"owner": DEFAULT_OWNER,
"trigger_rule": DEFAULT_TRIGGER_RULE,
"depends_on_past": False,
"ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
"wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
"wait_for_downstream": False,
"retries": DEFAULT_RETRIES,
"queue": DEFAULT_QUEUE,
"pool_slots": DEFAULT_POOL_SLOTS,
"execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT,
"retry_delay": DEFAULT_RETRY_DELAY,
"retry_exponential_backoff": False,
"priority_weight": DEFAULT_PRIORITY_WEIGHT,
"weight_rule": DEFAULT_WEIGHT_RULE,
"inlets": [],
"outlets": [],
}


# This is what handles the actual mapping.
def partial(
operator_class: type[BaseOperator],
*,
task_id: str,
dag: DAG | None = None,
task_group: TaskGroup | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
owner: str = DEFAULT_OWNER,
email: None | str | Iterable[str] = None,
start_date: datetime | ArgNotSet = NOTSET,
end_date: datetime | ArgNotSet = NOTSET,
owner: str | ArgNotSet = NOTSET,
email: None | str | Iterable[str] | ArgNotSet = NOTSET,
params: collections.abc.MutableMapping | None = None,
resources: dict[str, Any] | None = None,
trigger_rule: str = DEFAULT_TRIGGER_RULE,
depends_on_past: bool = False,
ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
wait_for_downstream: bool = False,
retries: int | None = DEFAULT_RETRIES,
queue: str = DEFAULT_QUEUE,
pool: str | None = None,
pool_slots: int = DEFAULT_POOL_SLOTS,
execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
max_retry_delay: None | timedelta | float = None,
retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
retry_exponential_backoff: bool = False,
priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
weight_rule: str = DEFAULT_WEIGHT_RULE,
sla: timedelta | None = None,
max_active_tis_per_dag: int | None = None,
max_active_tis_per_dagrun: int | None = None,
on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
run_as_user: str | None = None,
executor_config: dict | None = None,
inlets: Any | None = None,
outlets: Any | None = None,
doc: str | None = None,
doc_md: str | None = None,
doc_json: str | None = None,
doc_yaml: str | None = None,
doc_rst: str | None = None,
resources: dict[str, Any] | None | ArgNotSet = NOTSET,
trigger_rule: str | ArgNotSet = NOTSET,
depends_on_past: bool | ArgNotSet = NOTSET,
ignore_first_depends_on_past: bool | ArgNotSet = NOTSET,
wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET,
wait_for_downstream: bool | ArgNotSet = NOTSET,
retries: int | None | ArgNotSet = NOTSET,
queue: str | ArgNotSet = NOTSET,
pool: str | ArgNotSet = NOTSET,
pool_slots: int | ArgNotSet = NOTSET,
execution_timeout: timedelta | None | ArgNotSet = NOTSET,
max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET,
retry_delay: timedelta | float | ArgNotSet = NOTSET,
retry_exponential_backoff: bool | ArgNotSet = NOTSET,
priority_weight: int | ArgNotSet = NOTSET,
weight_rule: str | ArgNotSet = NOTSET,
sla: timedelta | None | ArgNotSet = NOTSET,
max_active_tis_per_dag: int | None | ArgNotSet = NOTSET,
max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET,
on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
run_as_user: str | None | ArgNotSet = NOTSET,
executor_config: dict | None | ArgNotSet = NOTSET,
inlets: Any | None | ArgNotSet = NOTSET,
outlets: Any | None | ArgNotSet = NOTSET,
doc: str | None | ArgNotSet = NOTSET,
doc_md: str | None | ArgNotSet = NOTSET,
doc_json: str | None | ArgNotSet = NOTSET,
doc_yaml: str | None | ArgNotSet = NOTSET,
doc_rst: str | None | ArgNotSet = NOTSET,
**kwargs,
) -> OperatorPartial:
from airflow.models.dag import DagContext
Expand All @@ -242,54 +263,62 @@ def partial(
task_id = task_group.child_id(task_id)

# Merge DAG and task group level defaults into user-supplied values.
partial_kwargs, partial_params = get_merged_defaults(
dag_default_args, partial_params = get_merged_defaults(
dag=dag,
task_group=task_group,
task_params=params,
task_default_args=kwargs.pop("default_args", None),
)
partial_kwargs.update(kwargs)

# Always fully populate partial kwargs to exclude them from map().
partial_kwargs.setdefault("dag", dag)
partial_kwargs.setdefault("task_group", task_group)
partial_kwargs.setdefault("task_id", task_id)
partial_kwargs.setdefault("start_date", start_date)
partial_kwargs.setdefault("end_date", end_date)
partial_kwargs.setdefault("owner", owner)
partial_kwargs.setdefault("email", email)
partial_kwargs.setdefault("trigger_rule", trigger_rule)
partial_kwargs.setdefault("depends_on_past", depends_on_past)
partial_kwargs.setdefault("ignore_first_depends_on_past", ignore_first_depends_on_past)
partial_kwargs.setdefault("wait_for_past_depends_before_skipping", wait_for_past_depends_before_skipping)
partial_kwargs.setdefault("wait_for_downstream", wait_for_downstream)
partial_kwargs.setdefault("retries", retries)
partial_kwargs.setdefault("queue", queue)
partial_kwargs.setdefault("pool", pool)
partial_kwargs.setdefault("pool_slots", pool_slots)
partial_kwargs.setdefault("execution_timeout", execution_timeout)
partial_kwargs.setdefault("max_retry_delay", max_retry_delay)
partial_kwargs.setdefault("retry_delay", retry_delay)
partial_kwargs.setdefault("retry_exponential_backoff", retry_exponential_backoff)
partial_kwargs.setdefault("priority_weight", priority_weight)
partial_kwargs.setdefault("weight_rule", weight_rule)
partial_kwargs.setdefault("sla", sla)
partial_kwargs.setdefault("max_active_tis_per_dag", max_active_tis_per_dag)
partial_kwargs.setdefault("max_active_tis_per_dagrun", max_active_tis_per_dagrun)
partial_kwargs.setdefault("on_execute_callback", on_execute_callback)
partial_kwargs.setdefault("on_failure_callback", on_failure_callback)
partial_kwargs.setdefault("on_retry_callback", on_retry_callback)
partial_kwargs.setdefault("on_success_callback", on_success_callback)
partial_kwargs.setdefault("run_as_user", run_as_user)
partial_kwargs.setdefault("executor_config", executor_config)
partial_kwargs.setdefault("inlets", inlets or [])
partial_kwargs.setdefault("outlets", outlets or [])
partial_kwargs.setdefault("resources", resources)
partial_kwargs.setdefault("doc", doc)
partial_kwargs.setdefault("doc_json", doc_json)
partial_kwargs.setdefault("doc_md", doc_md)
partial_kwargs.setdefault("doc_rst", doc_rst)
partial_kwargs.setdefault("doc_yaml", doc_yaml)

# Create partial_kwargs from args and kwargs
partial_kwargs: dict[str, Any] = {
**kwargs,
"dag": dag,
"task_group": task_group,
"task_id": task_id,
"start_date": start_date,
"end_date": end_date,
"owner": owner,
"email": email,
"trigger_rule": trigger_rule,
"depends_on_past": depends_on_past,
"ignore_first_depends_on_past": ignore_first_depends_on_past,
"wait_for_past_depends_before_skipping": wait_for_past_depends_before_skipping,
"wait_for_downstream": wait_for_downstream,
"retries": retries,
"queue": queue,
"pool": pool,
"pool_slots": pool_slots,
"execution_timeout": execution_timeout,
"max_retry_delay": max_retry_delay,
"retry_delay": retry_delay,
"retry_exponential_backoff": retry_exponential_backoff,
"priority_weight": priority_weight,
"weight_rule": weight_rule,
"sla": sla,
"max_active_tis_per_dag": max_active_tis_per_dag,
"max_active_tis_per_dagrun": max_active_tis_per_dagrun,
"on_execute_callback": on_execute_callback,
"on_failure_callback": on_failure_callback,
"on_retry_callback": on_retry_callback,
"on_success_callback": on_success_callback,
"run_as_user": run_as_user,
"executor_config": executor_config,
"inlets": inlets,
"outlets": outlets,
"resources": resources,
"doc": doc,
"doc_json": doc_json,
"doc_md": doc_md,
"doc_rst": doc_rst,
"doc_yaml": doc_yaml,
}

# Inject DAG-level default args into args provided to this function.
partial_kwargs.update((k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k) is NOTSET)

# Fill fields not provided by the user with default values.
partial_kwargs = {k: _PARTIAL_DEFAULTS.get(k) if v is NOTSET else v for k, v in partial_kwargs.items()}

# Post-process arguments. Should be kept in sync with _TaskDecorator.expand().
if "task_concurrency" in kwargs: # Reject deprecated option.
Expand Down
14 changes: 14 additions & 0 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,20 @@ def test_task_mapping_default_args():
assert mapped.start_date == pendulum.instance(default_args["start_date"])


def test_task_mapping_override_default_args():
    """Check the precedence of values seen on a mapped operator.

    An argument passed explicitly to ``partial()`` must win over the DAG's
    ``default_args``; an argument only present in ``default_args`` must be
    taken from there; an argument given nowhere must get the built-in default.
    """
    dag_defaults = {"retries": 2, "start_date": DEFAULT_DATE.now()}
    with DAG("test-dag", start_date=DEFAULT_DATE, default_args=dag_defaults):
        partial_op = MockOperator.partial(task_id="task", retries=1)
        mapped = partial_op.expand(arg2=["a", "b", "c"])

    # Explicit partial argument: the DAG-level value (2) must not replace it.
    assert mapped.partial_kwargs["retries"] == 1
    # Not passed to partial(): the DAG-level default is used instead.
    assert mapped.start_date == pendulum.instance(dag_defaults["start_date"])
    # Passed nowhere: the stock Airflow owner applies.
    assert mapped.owner == "airflow"


def test_map_unknown_arg_raises():
with pytest.raises(TypeError, match=r"argument 'file'"):
BaseOperator.partial(task_id="a").expand(file=[1, 2, {"a": "b"}])
Expand Down

0 comments on commit 4fac945

Please sign in to comment.