Skip to content

Commit

Permalink
Fix literal cross product expansion (#23434)
Browse files Browse the repository at this point in the history
(cherry picked from commit 3fb8e0b)
  • Loading branch information
jedcunningham authored and ephraimbuddy committed May 8, 2022
1 parent e63f62c commit 4d1f600
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
5 changes: 4 additions & 1 deletion airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,10 @@ def parse_time_mapped_ti_count(self) -> Optional[int]:
if not isinstance(value, MAPPABLE_LITERAL_TYPES):
# None literal type encountered, so give up
return None
total += len(value)
if total == 0:
total = len(value)
else:
total *= len(value)
return total

@cache
Expand Down
36 changes: 36 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2689,6 +2689,7 @@ def show(a, b):
ti.run()

show_task = dag.get_task("show")
assert show_task.parse_time_mapped_ti_count is None
mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session)
assert num == len(mapped_tis) == 4

Expand All @@ -2697,6 +2698,41 @@ def show(a, b):
ti.run()
assert outputs == [(1, 1), (1, 2), (2, 1), (2, 2)]

def test_map_literal_cross_product(self, dag_maker, session):
"""Test a mapped task with literal cross product args expand properly."""
outputs = []

with dag_maker(dag_id="product_same_types", session=session) as dag:

@dag.task
def show(a, b):
outputs.append((a, b))

show.expand(a=[2, 4, 8], b=[5, 10])

dag_run = dag_maker.create_dagrun()

show_task = dag.get_task("show")
assert show_task.parse_time_mapped_ti_count == 6
mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session)
assert len(mapped_tis) == 0 # Expanded at parse!
assert num == 6

tis = (
session.query(TaskInstance)
.filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.task_id == 'show',
TaskInstance.run_id == dag_run.run_id,
)
.order_by(TaskInstance.map_index)
.all()
)
for ti in tis:
ti.refresh_from_task(show_task)
ti.run()
assert outputs == [(2, 5), (2, 10), (4, 5), (4, 10), (8, 5), (8, 10)]

def test_map_in_group(self, tmp_path: pathlib.Path, dag_maker, session):
out = tmp_path.joinpath("out")
out.touch()
Expand Down

0 comments on commit 4d1f600

Please sign in to comment.