Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix rendering the mapped parameters when using expand_kwargs method #32272

Merged
merged 9 commits into from
Aug 18, 2023
5 changes: 4 additions & 1 deletion airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,10 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any]
f"expand_kwargs() input dict keys must all be str, "
f"but {key!r} is of type {_describe_type(key)}"
)
return mapping, {id(v) for v in mapping.values()}
# filter out parse time resolved values from the resolved_oids
resolved_oids = {id(v) for k, v in mapping.items() if not _is_parse_time_mappable(v)}

return mapping, resolved_oids


EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value.
Expand Down
89 changes: 77 additions & 12 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import logging
from collections import defaultdict
from datetime import timedelta
from unittest import mock
from unittest.mock import patch

import pendulum
Expand Down Expand Up @@ -399,17 +400,31 @@ def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expect


def test_mapped_render_template_fields_validating_operator(dag_maker, session):
class MyOperator(MockOperator):
def __init__(self, value, arg1, **kwargs):
assert isinstance(value, str), "value should have been resolved before unmapping"
assert isinstance(arg1, str), "value should have been resolved before unmapping"
super().__init__(arg1=arg1, **kwargs)
self.value = value
class MyOperator(BaseOperator):
template_fields = ("partial_template", "map_template", "file_template")
template_ext = (".ext",)

def __init__(
self, partial_template, partial_static, map_template, map_static, file_template, **kwargs
):
for value in [partial_template, partial_static, map_template, map_static, file_template]:
assert isinstance(value, str), "value should have been resolved before unmapping"
super().__init__(**kwargs)
self.partial_template = partial_template
self.partial_static = partial_static
self.map_template = map_template
self.map_static = map_static
self.file_template = file_template

def execute(self, context):
pass

with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
output1 = task1.output
mapped = MyOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand(value=output1, arg1=output1)
mapped = MyOperator.partial(
task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}"
).expand(map_template=output1, map_static=output1, file_template=["/path/to/file.ext"])

dr = dag_maker.create_dagrun()
ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
Expand All @@ -432,12 +447,62 @@ def __init__(self, value, arg1, **kwargs):
mapped_ti.map_index = 0

assert isinstance(mapped_ti.task, MappedOperator)
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch(
"os.path.isfile", return_value=True
), patch("os.path.getmtime", return_value=0):
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
assert isinstance(mapped_ti.task, MyOperator)

assert mapped_ti.task.partial_template == "a", "Should be templated!"
assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!"
assert mapped_ti.task.map_template == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.file_template == "loaded data", "Should be templated!"


def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session):
class MyOperator(BaseOperator):
template_fields = ("partial_template", "map_template", "file_template")
template_ext = (".ext",)

def __init__(
self, partial_template, partial_static, map_template, map_static, file_template, **kwargs
):
for value in [partial_template, partial_static, map_template, map_static, file_template]:
assert isinstance(value, str), "value should have been resolved before unmapping"
super().__init__(**kwargs)
self.partial_template = partial_template
self.partial_static = partial_static
self.map_template = map_template
self.map_static = map_static
self.file_template = file_template

def execute(self, context):
pass

with dag_maker(session=session):
mapped = MyOperator.partial(
task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}"
).expand_kwargs(
[{"map_template": "{{ ds }}", "map_static": "{{ ds }}", "file_template": "/path/to/file.ext"}]
)

dr = dag_maker.create_dagrun()

mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0)

assert isinstance(mapped_ti.task, MappedOperator)
with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch(
"os.path.isfile", return_value=True
), patch("os.path.getmtime", return_value=0):
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
assert isinstance(mapped_ti.task, MyOperator)

assert mapped_ti.task.value == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.arg1 == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.arg2 == "a"
assert mapped_ti.task.partial_template == "a", "Should be templated!"
assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!"
assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!"
assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.file_template == "loaded data", "Should be templated!"


def test_mapped_render_nested_template_fields(dag_maker, session):
Expand Down Expand Up @@ -534,7 +599,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis
@pytest.mark.parametrize(
"map_index, expected",
[
pytest.param(0, "{{ ds }}", id="0"),
pytest.param(0, "2016-01-01", id="0"),
pytest.param(1, 2, id="1"),
],
)
Expand Down