Skip to content

Commit

Permalink
add template_in_template arg to expand method to tell Airflow whether…
Browse files Browse the repository at this point in the history
… to resolve the xcom data or not
  • Loading branch information
hussein-awala committed Aug 8, 2023
1 parent eb73452 commit 14bd392
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
16 changes: 11 additions & 5 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,16 @@ def __del__(self):
task_id = f"at {hex(id(self))}"
warnings.warn(f"Task {task_id} was never mapped!")

def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
def expand(self, template_in_template=False, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
if not mapped_kwargs:
raise TypeError("no arguments to expand against")
validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
# Since the input is already checked at parse time, we can set strict
# to False to skip the checks on execution.
return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)
return self._expand(
DictOfListsExpandInput(mapped_kwargs), strict=False, template_in_template=template_in_template
)

def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
from airflow.models.xcom_arg import XComArg
Expand All @@ -182,7 +184,9 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool =
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
def _expand(
self, expand_input: ExpandInput, *, strict: bool, template_in_template: bool = False
) -> MappedOperator:
from airflow.operators.empty import EmptyOperator

self._expand_called = True
Expand Down Expand Up @@ -225,6 +229,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
# For classic operators, this points to expand_input because kwargs
# to BaseOperator.expand() contribute to operator arguments.
expand_input_attr="expand_input",
template_in_template=template_in_template,
)
return op

Expand Down Expand Up @@ -261,6 +266,7 @@ class MappedOperator(AbstractOperator):
template_ext: Sequence[str]
template_fields: Collection[str]
template_fields_renderers: dict[str, str]
template_in_template: bool = False
ui_color: str
ui_fgcolor: str
_is_empty: bool
Expand Down Expand Up @@ -721,7 +727,7 @@ def render_template_fields(
# in the weeds here. We don't close this session for the same reason.
session = settings.Session()

mapped_kwargs, _ = self._expand_mapped_kwargs(context, session)
mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
unmapped_task = self.unmap(mapped_kwargs)
context_update_for_unmapped(context, unmapped_task)

Expand All @@ -734,6 +740,6 @@ def render_template_fields(
template_fields=self.template_fields,
context=context,
jinja_env=jinja_env,
seen_oids=set(),
seen_oids=seen_oids if not self.template_in_template else set(),
session=session,
)
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
# Only on MappedOperator.
"expand_input",
"partial_kwargs",
"template_in_template",
}


Expand Down
11 changes: 7 additions & 4 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,8 @@ def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expect
assert t.expand_input.value == {"params": [{"c": "x"}, {"d": 1}]}


def test_mapped_render_template_fields_validating_operator(dag_maker, session):
@pytest.mark.parametrize("template_in_template", [True, False])
def test_mapped_render_template_fields_validating_operator(dag_maker, session, template_in_template):
class MyOperator(BaseOperator):
template_fields = ("partial_template", "map_template")

Expand All @@ -419,7 +420,7 @@ def execute(self, context):
output1 = task1.output
mapped = MyOperator.partial(
task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}"
).expand(map_template=output1, map_static=output1)
).expand(template_in_template=template_in_template, map_template=output1, map_static=output1)

dr = dag_maker.create_dagrun()
ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
Expand Down Expand Up @@ -447,7 +448,9 @@ def execute(self, context):

assert mapped_ti.task.partial_template == "a", "Should be templated!"
assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!"
assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!"
assert (
mapped_ti.task.map_template == "2016-01-01" if template_in_template else "{{ ds }}"
), "Should be templated only when template_in_template is set to True!"
assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!"


Expand Down Expand Up @@ -545,7 +548,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis
@pytest.mark.parametrize(
"map_index, expected",
[
pytest.param(0, "2016-01-01", id="0"),
pytest.param(0, "{{ ds }}", id="0"),
pytest.param(1, 2, id="1"),
],
)
Expand Down

0 comments on commit 14bd392

Please sign in to comment.