From cb35f684d80901384f5533f7b92c62523ef1f197 Mon Sep 17 00:00:00 2001 From: Jed Cunningham Date: Mon, 2 May 2022 21:34:04 -0600 Subject: [PATCH 1/2] Fix literal cross product expansion --- airflow/models/mappedoperator.py | 5 ++++- tests/models/test_taskinstance.py | 35 +++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index aa51a73454141..b63e26ec9e48b 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -795,7 +795,10 @@ def parse_time_mapped_ti_count(self) -> Optional[int]: if not isinstance(value, MAPPABLE_LITERAL_TYPES): # None literal type encountered, so give up return None - total += len(value) + if total == 0: + total = len(value) + else: + total *= len(value) return total @cache diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index e53b52e11b1c3..074b6fc599011 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2689,6 +2689,7 @@ def show(a, b): ti.run() show_task = dag.get_task("show") + assert show_task.parse_time_mapped_ti_count is None mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session) assert num == len(mapped_tis) == 4 @@ -2697,6 +2698,40 @@ def show(a, b): ti.run() assert outputs == [(1, 1), (1, 2), (2, 1), (2, 2)] + def test_map_literal_cross_product(self, dag_maker, session): + """Test a mapped task with literal cross product args expand properly.""" + outputs = [] + + with dag_maker(dag_id="product_same_types", session=session) as dag: + + @dag.task + def show(a, b): + outputs.append((a, b)) + + show.expand(a=[2, 4, 8], b=[5, 10]) + + dag_run = dag_maker.create_dagrun() + + show_task = dag.get_task("show") + assert show_task.parse_time_mapped_ti_count == 6 + mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session) + assert len(mapped_tis) == 0 # Expanded at parse! + assert num == 6 + + tis = ( + session.query(TaskInstance) + .filter( + TaskInstance.dag_id == dag.dag_id, + TaskInstance.task_id == 'show', + TaskInstance.run_id == dag_run.run_id, + ) + .all() + ) + for ti in tis: + ti.refresh_from_task(show_task) + ti.run() + assert outputs == [(2, 5), (2, 10), (4, 5), (4, 10), (8, 5), (8, 10)] + def test_map_in_group(self, tmp_path: pathlib.Path, dag_maker, session): out = tmp_path.joinpath("out") out.touch() From 3be09f48fb7022f1638439ed989472d932c2cef2 Mon Sep 17 00:00:00 2001 From: Jed Cunningham Date: Tue, 3 May 2022 22:42:17 -0600 Subject: [PATCH 2/2] Make test non-flaky --- tests/models/test_taskinstance.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 074b6fc599011..dac987e4319d4 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2725,6 +2725,7 @@ def show(a, b): TaskInstance.task_id == 'show', TaskInstance.run_id == dag_run.run_id, ) + .order_by(TaskInstance.map_index) .all() ) for ti in tis: