diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index aa51a73454141..b63e26ec9e48b 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -795,7 +795,10 @@ def parse_time_mapped_ti_count(self) -> Optional[int]: if not isinstance(value, MAPPABLE_LITERAL_TYPES): # None literal type encountered, so give up return None - total += len(value) + if total == 0: + total = len(value) + else: + total *= len(value) return total @cache diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index e53b52e11b1c3..dac987e4319d4 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2689,6 +2689,7 @@ def show(a, b): ti.run() show_task = dag.get_task("show") + assert show_task.parse_time_mapped_ti_count is None mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session) assert num == len(mapped_tis) == 4 @@ -2697,6 +2698,41 @@ def show(a, b): ti.run() assert outputs == [(1, 1), (1, 2), (2, 1), (2, 2)] + def test_map_literal_cross_product(self, dag_maker, session): + """Test a mapped task with literal cross product args expand properly.""" + outputs = [] + + with dag_maker(dag_id="product_same_types", session=session) as dag: + + @dag.task + def show(a, b): + outputs.append((a, b)) + + show.expand(a=[2, 4, 8], b=[5, 10]) + + dag_run = dag_maker.create_dagrun() + + show_task = dag.get_task("show") + assert show_task.parse_time_mapped_ti_count == 6 + mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session) + assert len(mapped_tis) == 0 # Expanded at parse! + assert num == 6 + + tis = ( + session.query(TaskInstance) + .filter( + TaskInstance.dag_id == dag.dag_id, + TaskInstance.task_id == 'show', + TaskInstance.run_id == dag_run.run_id, + ) + .order_by(TaskInstance.map_index) + .all() + ) + for ti in tis: + ti.refresh_from_task(show_task) + ti.run() + assert outputs == [(2, 5), (2, 10), (4, 5), (4, 10), (8, 5), (8, 10)] + def test_map_in_group(self, tmp_path: pathlib.Path, dag_maker, session): out = tmp_path.joinpath("out") out.touch()