From 01b224a0df7ecf852c276390d9cb815d91d78c0d Mon Sep 17 00:00:00 2001
From: Sebastien Menard
Date: Mon, 12 Dec 2022 19:02:35 +0000
Subject: [PATCH] fix(ingest) - Basic support for dynamic task mapping in Airflow plugin

---
 .../src/datahub_provider/_plugin.py           | 36 ++++++++++++++++---
 1 file changed, 30 insertions(+), 6 deletions(-)

diff --git a/metadata-ingestion/src/datahub_provider/_plugin.py b/metadata-ingestion/src/datahub_provider/_plugin.py
index 7f5de876c4428a..64e5c81f46f46a 100644
--- a/metadata-ingestion/src/datahub_provider/_plugin.py
+++ b/metadata-ingestion/src/datahub_provider/_plugin.py
@@ -1,13 +1,12 @@
-from datahub_provider._airflow_compat import Operator
+from datahub_provider._airflow_compat import BaseOperator, MappedOperator, Operator
 
 import contextlib
 import logging
 import traceback
-from typing import Any, Callable, Iterable, List, Optional
+from typing import Any, Callable, Iterable, List, Optional, Union
 
 from airflow.configuration import conf
 from airflow.lineage import PIPELINE_OUTLETS
-from airflow.models.baseoperator import BaseOperator
 from airflow.plugins_manager import AirflowPlugin
 from airflow.utils.module_loading import import_string
 from cattr import structure
@@ -290,12 +289,37 @@ def custom_on_success_callback(context):
     return custom_on_success_callback
 
 
-def task_policy(task: BaseOperator) -> None:
+def task_policy(task: Union[BaseOperator, MappedOperator]) -> None:
     task.log.debug(f"Setting task policy for Dag: {task.dag_id} Task: {task.task_id}")
     # task.add_inlets(["auto"])
     # task.pre_execute = _wrap_pre_execution(task.pre_execute)
-    task.on_failure_callback = _wrap_on_failure_callback(task.on_failure_callback)
-    task.on_success_callback = _wrap_on_success_callback(task.on_success_callback)
+
+    # MappedOperator's callbacks don't have setters until Airflow 2.X.X
+    # https://github.com/apache/airflow/issues/24547
+    # We can bypass this by going through partial_kwargs for now
+    if MappedOperator and isinstance(task, MappedOperator):  # type: ignore
+        on_failure_callback_prop: property = getattr(
+            MappedOperator, "on_failure_callback"
+        )
+        on_success_callback_prop: property = getattr(
+            MappedOperator, "on_success_callback"
+        )
+        if not on_failure_callback_prop.fset or not on_success_callback_prop.fset:
+            task.log.debug(
+                "Using MappedOperator's partial_kwargs instead of callback properties"
+            )
+            task.partial_kwargs["on_failure_callback"] = _wrap_on_failure_callback(
+                task.on_failure_callback
+            )
+            task.partial_kwargs["on_success_callback"] = _wrap_on_success_callback(
+                task.on_success_callback
+            )
+            # Callbacks are now wrapped via partial_kwargs; stop here because
+            # assigning through a setter-less property raises AttributeError.
+            return
+
+    task.on_failure_callback = _wrap_on_failure_callback(task.on_failure_callback)  # type: ignore
+    task.on_success_callback = _wrap_on_success_callback(task.on_success_callback)  # type: ignore
     # task.pre_execute = _wrap_pre_execution(task.pre_execute)
 
 