From bc5adae3377d1578249667a2f5d522b167b8f340 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 21 Jan 2022 15:04:40 +0000 Subject: [PATCH 1/7] Function to expand mapped tasks in to multiple TIs --- airflow/models/baseoperator.py | 48 ++++++++++++++++++++++ airflow/models/taskinstance.py | 3 +- tests/models/test_baseoperator.py | 66 ++++++++++++++++++++++++++++++- tests/models/test_dagrun.py | 18 +++++++++ 4 files changed, 132 insertions(+), 3 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index dbe63b4e9c962..2f1d80498b63d 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -75,6 +75,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.operator_resources import Resources from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.state import TaskInstanceState from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule @@ -1800,6 +1801,53 @@ def wait_for_downstream(self) -> bool: def depends_on_past(self) -> bool: return self.partial_kwargs.get("depends_on_past") or self.wait_for_downstream + def expand_mapped_task(self, upstream_ti: "TaskInstance", session: "Session" = NEW_SESSION) -> None: + """Create the mapped TaskInstances for mapped task.""" + # TODO: support having multiuple mapped upstreams? + from airflow.models.taskmap import TaskMap + from airflow.settings import task_instance_mutation_hook + + task_map_info: TaskMap = ( + session.query(TaskMap) + .filter_by( + dag_id=upstream_ti.dag_id, + task_id=upstream_ti.task_id, + run_id=upstream_ti.run_id, + map_index=upstream_ti.map_index, + ) + .one() + ) + + unmapped_ti: Optional[TaskInstance] = upstream_ti.dag_run.get_task_instance( + self.task_id, map_index=-1, session=session + ) + + maps = range(task_map_info.length) + + if unmapped_ti: + # The unmapped TaskInstance still exisxts -- this means we haven't + # tried to run it before. 
+ unmapped_ti.map_index = 0 + maps = range(1, task_map_info.length) + + for index in maps: + # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings + # TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator + ti = TaskInstance(self, run_id=upstream_ti.run_id, map_index=index) # type: ignore + task_instance_mutation_hook(ti) + session.merge(ti) + + # Set to "REMOVED" any (old) TaskInstances with map indices greater + # than the current map value + session.query(TaskInstance).filter( + TaskInstance.dag_id == upstream_ti.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.run_id == upstream_ti.run_id, + TaskInstance.map_index >= task_map_info.length, + ).update({TaskInstance.state: TaskInstanceState.REMOVED}) + + session.flush() + # TODO: Deprecate for Airflow 3.0 Chainable = Union[DependencyMixin, Sequence[DependencyMixin]] diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 075027ce81566..f1f1b93a63157 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -486,7 +486,7 @@ def __init__( self.test_mode = False @staticmethod - def insert_mapping(run_id: str, task: "BaseOperator") -> dict: + def insert_mapping(run_id: str, task: "BaseOperator", map_index: int = -1) -> dict: """:meta private:""" return { 'dag_id': task.dag_id, @@ -503,6 +503,7 @@ def insert_mapping(run_id: str, task: "BaseOperator") -> dict: 'max_tries': task.retries, 'executor_config': task.executor_config, 'operator': task.task_type, + 'map_index': map_index, } @reconstructor diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index ee79aa9a75bdb..1311e2e99acd1 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -36,8 +36,12 @@ chain, cross_downstream, ) +from airflow.models.taskinstance import TaskInstance +from airflow.models.taskmap import TaskMap +from airflow.models.xcom_arg import XComArg from airflow.utils.context import Context from airflow.utils.edgemodifier import Label +from airflow.utils.state import TaskInstanceState from airflow.utils.task_group import TaskGroup from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule @@ -733,8 +737,6 @@ def test_map_unknown_arg_raises(): def test_map_xcom_arg(): """Test that dependencies are correct when mapping with an XComArg""" - from airflow.models.xcom_arg import XComArg - with DAG("test-dag", start_date=DEFAULT_DATE): task1 = BaseOperator(task_id="op1") xcomarg = XComArg(task1, "test_key") @@ -767,3 +769,63 @@ def test_partial_on_class_invalid_ctor_args() -> None: """ with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"): MockOperator.partial(task_id='a', foo='bar', bar=2) + + +@pytest.mark.parametrize( + ["num_existing_tis", "expected"], + ( + pytest.param(0, [(0, None), (1, None), (2, None)], id='only-unmapped-ti-exists'), + pytest.param(3, [(0, None), (1, None), (2, None)], id='all-tis-exist'), + pytest.param( + 5, + [(0, None), (1, None), (2, None), (3, TaskInstanceState.REMOVED), (4, TaskInstanceState.REMOVED)], + id="tis-to-be-remove", + ), + ), +) +def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expected): + literal = [1, 2, {'a': 'b'}] + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + xcomarg = XComArg(task1, "test_key") + mapped = MockOperator(task_id='task_2').map(arg2=xcomarg) + + dr = dag_maker.create_dagrun() + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + 
run_id=dr.run_id, + map_index=-1, + length=len(literal), + keys=None, + ) + ) + + if num_existing_tis: + # Remove the map_index=-1 TI when we're creating other TIs + session.query(TaskInstance).filter( + TaskInstance.dag_id == mapped.dag_id, + TaskInstance.task_id == mapped.task_id, + TaskInstance.run_id == dr.run_id, + ).delete() + + for index in range(num_existing_tis): + ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index) + session.add(ti) + session.flush() + + mapped.expand_mapped_task( + upstream_ti=dr.get_task_instance(task1.task_id), + session=session, + ) + + indices = ( + session.query(TaskInstance.map_index, TaskInstance.state) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TaskInstance.map_index) + .all() + ) + + assert indices == expected diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 56f214009f111..c5f048554b0dc 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -39,6 +39,7 @@ from airflow.utils.types import DagRunType from tests.models import DEFAULT_DATE from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs +from tests.test_utils.mock_operators import MockOperator class TestDagRun(unittest.TestCase): @@ -874,3 +875,20 @@ def test_verify_integrity_task_start_date(Stats_incr, session, run_type, expecte assert len(tis) == expected_tis Stats_incr.assert_called_with('task_instance_created-DummyOperator', expected_tis) + + +@pytest.mark.xfail(reason="TODO: Expand mapped literals at verify_integrity time!") +def test_expand_mapped_task_instance(dag_maker, session): + literal = [1, 2, {'a': 'b'}] + with dag_maker(session=session): + mapped = MockOperator(task_id='task_2').map(arg2=literal) + + dr = dag_maker.create_dagrun() + indices = ( + session.query(TI.map_index) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TI.insert_mapping) + .all() + ) + + assert indices == [0, 1, 2] From d3b9fdc6406421f89a061698a052def32e6df732 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 24 Jan 2022 21:21:03 +0000 Subject: [PATCH 2/7] Update airflow/models/baseoperator.py Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> --- airflow/models/baseoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 2f1d80498b63d..0c0798da7cf85 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1825,7 +1825,7 @@ def expand_mapped_task(self, upstream_ti: "TaskInstance", session: "Session" = N maps = range(task_map_info.length) if unmapped_ti: - # The unmapped TaskInstance still exisxts -- this means we haven't + # The unmapped TaskInstance still exists -- this means we haven't # tried to run it before. 
unmapped_ti.map_index = 0 maps = range(1, task_map_info.length) From 6878471ae370688f4d7aa51f53b49423e5cd5f0e Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 26 Jan 2022 14:01:10 +0800 Subject: [PATCH 3/7] Mark unmapped ti as SKIPPED if upstream is empty --- airflow/models/baseoperator.py | 48 ++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 0c0798da7cf85..9c22d21d39bd8 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -75,7 +75,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.operator_resources import Resources from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.state import TaskInstanceState +from airflow.utils.state import State, TaskInstanceState from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule @@ -1807,31 +1807,51 @@ def expand_mapped_task(self, upstream_ti: "TaskInstance", session: "Session" = N from airflow.models.taskmap import TaskMap from airflow.settings import task_instance_mutation_hook - task_map_info: TaskMap = ( - session.query(TaskMap) + task_map_info_length: Optional[int] = ( + session.query(TaskMap.length) .filter_by( dag_id=upstream_ti.dag_id, task_id=upstream_ti.task_id, run_id=upstream_ti.run_id, map_index=upstream_ti.map_index, ) - .one() + .scalar() ) + if task_map_info_length is None: + # TODO: What would lead to this? How can this be better handled? + raise RuntimeError("mapped operator cannot be expanded; upstream not found") + # TODO: Add db constraint to ensure this is never negative. - unmapped_ti: Optional[TaskInstance] = upstream_ti.dag_run.get_task_instance( - self.task_id, map_index=-1, session=session + unmapped_ti: Optional[TaskInstance] = ( + session.query(TaskInstance) + .filter( + TaskInstance.dag_id == upstream_ti.dag_id, + TaskInstance.run_id == upstream_ti.run_id, + TaskInstance.task_id == self.task_id, + TaskInstance.map_index == -1, + TaskInstance.state.in_(State.unfinished), + ) + .one_or_none() ) - maps = range(task_map_info.length) - if unmapped_ti: - # The unmapped TaskInstance still exists -- this means we haven't - # tried to run it before. + # The unmapped task instance still exists and is unfinished, i.e. we + # haven't tried to run it before. + if task_map_info_length < 1: + # If the upstream maps this to a zero-length value, simply marked the + # unmapped task instance as SKIPPED (if needed). + unmapped_ti.state = TaskInstanceState.SKIPPED + session.merge(unmapped_ti) + return + # Otherwise convert this into the first mapped index, and create + # TaskInstance for other indexes. unmapped_ti.map_index = 0 - maps = range(1, task_map_info.length) + indexes_to_map = range(1, task_map_info_length) + else: + indexes_to_map = range(task_map_info_length) - for index in maps: - # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings + for index in indexes_to_map: + # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings. 
# TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator ti = TaskInstance(self, run_id=upstream_ti.run_id, map_index=index) # type: ignore task_instance_mutation_hook(ti) @@ -1843,7 +1863,7 @@ def expand_mapped_task(self, upstream_ti: "TaskInstance", session: "Session" = N TaskInstance.dag_id == upstream_ti.dag_id, TaskInstance.task_id == self.task_id, TaskInstance.run_id == upstream_ti.run_id, - TaskInstance.map_index >= task_map_info.length, + TaskInstance.map_index >= task_map_info_length, ).update({TaskInstance.state: TaskInstanceState.REMOVED}) session.flush() From b7a5779f10603822ae8aa24b4fdd9436168daf9b Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 26 Jan 2022 09:26:56 +0000 Subject: [PATCH 4/7] Require map_index --- airflow/models/dagrun.py | 2 +- airflow/models/taskinstance.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 81d26bc490d5e..4f8c49e35e5de 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -809,7 +809,7 @@ def task_filter(task: "BaseOperator"): def create_ti_mapping(task: "BaseOperator"): created_counts[task.task_type] += 1 - return TI.insert_mapping(self.run_id, task) + return TI.insert_mapping(self.run_id, task, map_index=-1) else: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index f1f1b93a63157..a5d1c12b06727 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -486,7 +486,7 @@ def __init__( self.test_mode = False @staticmethod - def insert_mapping(run_id: str, task: "BaseOperator", map_index: int = -1) -> dict: + def insert_mapping(run_id: str, task: "BaseOperator", map_index: int) -> dict: """:meta private:""" return { 'dag_id': task.dag_id, From 0d32f72c48355646f82158dd8dbbe37f3a626b22 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 26 Jan 2022 09:30:52 +0000 Subject: [PATCH 5/7] Log when skipping task due to 0-length TaskMapInfo --- airflow/models/baseoperator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 9c22d21d39bd8..fd8db5be5cde1 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1840,6 +1840,7 @@ def expand_mapped_task(self, upstream_ti: "TaskInstance", session: "Session" = N if task_map_info_length < 1: # If the upstream maps this to a zero-length value, simply marked the # unmapped task instance as SKIPPED (if needed). + self.log.info("Marking %s as SKIPPED since the map has 0 values to expand", unmapped_ti) unmapped_ti.state = TaskInstanceState.SKIPPED session.merge(unmapped_ti) return From aeb7dfed161389e88b620f1fb5e851fa3dd5f974 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 26 Jan 2022 09:43:28 +0000 Subject: [PATCH 6/7] fixup! 
Mark unmapped ti as SKIPPED if upstream is empty --- airflow/models/baseoperator.py | 5 +++-- tests/models/test_baseoperator.py | 31 ++++++++++++++++++++++++++----- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index fd8db5be5cde1..1ee2b1b044543 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -50,6 +50,7 @@ import jinja2 import pendulum from dateutil.relativedelta import relativedelta +from sqlalchemy import or_ from sqlalchemy.orm import Session from sqlalchemy.orm.exc import NoResultFound @@ -1829,7 +1830,7 @@ def expand_mapped_task(self, upstream_ti: "TaskInstance", session: "Session" = N TaskInstance.run_id == upstream_ti.run_id, TaskInstance.task_id == self.task_id, TaskInstance.map_index == -1, - TaskInstance.state.in_(State.unfinished), + or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)), ) .one_or_none() ) @@ -1842,7 +1843,7 @@ def expand_mapped_task(self, upstream_ti: "TaskInstance", session: "Session" = N # unmapped task instance as SKIPPED (if needed). self.log.info("Marking %s as SKIPPED since the map has 0 values to expand", unmapped_ti) unmapped_ti.state = TaskInstanceState.SKIPPED - session.merge(unmapped_ti) + session.flush() return # Otherwise convert this into the first mapped index, and create # TaskInstance for other indexes. diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 1311e2e99acd1..a78ae4d9f5586 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -812,14 +812,11 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec ).delete() for index in range(num_existing_tis): - ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index) + ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index) # type: ignore session.add(ti) session.flush() - mapped.expand_mapped_task( - upstream_ti=dr.get_task_instance(task1.task_id), - session=session, - ) + mapped.expand_mapped_task(upstream_ti=dr.get_task_instance(task1.task_id), session=session) indices = ( session.query(TaskInstance.map_index, TaskInstance.state) @@ -829,3 +826,27 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec ) assert indices == expected + + +def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + xcomarg = XComArg(task1, "test_key") + mapped = MockOperator(task_id='task_2').map(arg2=xcomarg) + + dr = dag_maker.create_dagrun() + + session.add( + TaskMap(dag_id=dr.dag_id, task_id=task1.task_id, run_id=dr.run_id, map_index=-1, length=0, keys=None) + ) + + mapped.expand_mapped_task(upstream_ti=dr.get_task_instance(task1.task_id), session=session) + + indices = ( + session.query(TaskInstance.map_index, TaskInstance.state) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TaskInstance.map_index) + .all() + ) + + assert indices == [(-1, TaskInstanceState.SKIPPED)] From b331ae25db9cc1102a0706d9e470f2c9ad15c202 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 26 Jan 2022 09:44:52 +0000 Subject: [PATCH 7/7] fixup! 
Mark unmapped ti as SKIPPED if upstream is empty --- airflow/models/baseoperator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 1ee2b1b044543..d8a1bcef358c3 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1821,7 +1821,6 @@ def expand_mapped_task(self, upstream_ti: "TaskInstance", session: "Session" = N if task_map_info_length is None: # TODO: What would lead to this? How can this be better handled? raise RuntimeError("mapped operator cannot be expanded; upstream not found") - # TODO: Add db constraint to ensure this is never negative. unmapped_ti: Optional[TaskInstance] = ( session.query(TaskInstance)
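
For readers skimming the series, the expansion logic that PATCH 3/7 introduces
and the later fixups refine can be summarized with a small, framework-free
sketch. FakeTI and expand() below are illustrative stand-ins rather than
Airflow classes, the state strings are simplified placeholders for Airflow's
TaskInstanceState values, and map_index == -1 denotes the unmapped
TaskInstance, as in the patches.

# Framework-free sketch of the expansion rules added in PATCH 3/7.
# FakeTI and expand() are illustrative stand-ins, not Airflow code;
# map_index == -1 means "the unmapped TaskInstance", as in the patch.
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class FakeTI:
    map_index: int
    state: Optional[str] = None  # None stands in for "no state yet / unfinished"


def expand(tis: Dict[int, FakeTI], map_length: int) -> Dict[int, FakeTI]:
    """Expand a mapped task into `map_length` instances, mutating `tis`."""
    unmapped = tis.get(-1)
    if unmapped is not None and unmapped.state in (None, "scheduled", "queued"):
        if map_length < 1:
            # Upstream mapped over an empty value: nothing to run, so skip.
            unmapped.state = "skipped"
            return tis
        # Reuse the unmapped instance as map_index 0, then create the rest.
        tis[0] = tis.pop(-1)
        tis[0].map_index = 0
        first_new_index = 1
    else:
        first_new_index = 0
    for index in range(first_new_index, map_length):
        tis.setdefault(index, FakeTI(map_index=index))
    # Stale instances left over from a previously larger map get marked REMOVED.
    for index, ti in tis.items():
        if index >= map_length:
            ti.state = "removed"
    return tis


if __name__ == "__main__":
    # Mirrors the "tis-to-be-remove" parametrization in the new test:
    # five pre-existing mapped TIs, but the upstream map has length 3.
    tis = {i: FakeTI(map_index=i) for i in range(5)}
    expand(tis, map_length=3)
    assert [(i, ti.state) for i, ti in sorted(tis.items())] == [
        (0, None), (1, None), (2, None), (3, "removed"), (4, "removed"),
    ]

Reusing the unmapped instance as map_index 0, rather than creating a fresh
row, is what lets a task instance that was scheduled but never run survive
expansion instead of being recreated.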