diff --git a/airflow/executors/celery_kubernetes_executor.py b/airflow/executors/celery_kubernetes_executor.py index 5b934160abe1e..b1edc32235727 100644 --- a/airflow/executors/celery_kubernetes_executor.py +++ b/airflow/executors/celery_kubernetes_executor.py @@ -17,6 +17,8 @@ # under the License. from typing import Dict, List, Optional, Set, Union +from airflow.callbacks.base_callback_sink import BaseCallbackSink +from airflow.callbacks.callback_requests import CallbackRequest from airflow.configuration import conf from airflow.executors.base_executor import CommandType, EventBufferValueType, QueuedTaskInstanceType from airflow.executors.celery_executor import CeleryExecutor @@ -35,6 +37,7 @@ class CeleryKubernetesExecutor(LoggingMixin): """ supports_ad_hoc_ti_run: bool = True + callback_sink: Optional[BaseCallbackSink] = None KUBERNETES_QUEUE = conf.get('celery_kubernetes_executor', 'kubernetes_queue') @@ -204,3 +207,12 @@ def debug_dump(self) -> None: self.celery_executor.debug_dump() self.log.info("Dumping KubernetesExecutor state") self.kubernetes_executor.debug_dump() + + def send_callback(self, request: CallbackRequest) -> None: + """Sends callback for execution. + + :param request: Callback request to be executed. + """ + if not self.callback_sink: + raise ValueError("Callback sink is not ready.") + self.callback_sink.send(request) diff --git a/airflow/executors/local_kubernetes_executor.py b/airflow/executors/local_kubernetes_executor.py index cb1ddf7c9d220..9944cfe1ef1fa 100644 --- a/airflow/executors/local_kubernetes_executor.py +++ b/airflow/executors/local_kubernetes_executor.py @@ -17,6 +17,8 @@ # under the License. from typing import Dict, List, Optional, Set, Union +from airflow.callbacks.base_callback_sink import BaseCallbackSink +from airflow.callbacks.callback_requests import CallbackRequest from airflow.configuration import conf from airflow.executors.base_executor import CommandType, EventBufferValueType, QueuedTaskInstanceType from airflow.executors.kubernetes_executor import KubernetesExecutor @@ -35,6 +37,7 @@ class LocalKubernetesExecutor(LoggingMixin): """ supports_ad_hoc_ti_run: bool = True + callback_sink: Optional[BaseCallbackSink] = None KUBERNETES_QUEUE = conf.get('local_kubernetes_executor', 'kubernetes_queue') @@ -203,3 +206,12 @@ def debug_dump(self) -> None: self.local_executor.debug_dump() self.log.info("Dumping KubernetesExecutor state") self.kubernetes_executor.debug_dump() + + def send_callback(self, request: CallbackRequest) -> None: + """Sends callback for execution. + + :param request: Callback request to be executed. + """ + if not self.callback_sink: + raise ValueError("Callback sink is not ready.") + self.callback_sink.send(request) diff --git a/tests/executors/test_celery_kubernetes_executor.py b/tests/executors/test_celery_kubernetes_executor.py index 84ca14c5f08d7..5681476274cc9 100644 --- a/tests/executors/test_celery_kubernetes_executor.py +++ b/tests/executors/test_celery_kubernetes_executor.py @@ -19,6 +19,7 @@ from parameterized import parameterized +from airflow.callbacks.callback_requests import CallbackRequest from airflow.configuration import conf from airflow.executors.celery_executor import CeleryExecutor from airflow.executors.celery_kubernetes_executor import CeleryKubernetesExecutor @@ -223,3 +224,14 @@ def test_kubernetes_executor_knows_its_queue(self): assert k8s_executor_mock.kubernetes_queue == conf.get( 'celery_kubernetes_executor', 'kubernetes_queue' ) + + def test_send_callback(self): + cel_exec = CeleryExecutor() + k8s_exec = KubernetesExecutor() + cel_k8s_exec = CeleryKubernetesExecutor(cel_exec, k8s_exec) + cel_k8s_exec.callback_sink = mock.MagicMock() + + callback = CallbackRequest(full_filepath="fake") + cel_k8s_exec.send_callback(callback) + + cel_k8s_exec.callback_sink.send.assert_called_once_with(callback) diff --git a/tests/executors/test_local_kubernetes_executor.py b/tests/executors/test_local_kubernetes_executor.py index 274175f12750a..48d09ad99e5ed 100644 --- a/tests/executors/test_local_kubernetes_executor.py +++ b/tests/executors/test_local_kubernetes_executor.py @@ -17,6 +17,7 @@ # under the License. from unittest import mock +from airflow.callbacks.callback_requests import CallbackRequest from airflow.configuration import conf from airflow.executors.local_executor import LocalExecutor from airflow.executors.local_kubernetes_executor import LocalKubernetesExecutor @@ -67,3 +68,14 @@ def test_kubernetes_executor_knows_its_queue(self): LocalKubernetesExecutor(local_executor_mock, k8s_executor_mock) assert k8s_executor_mock.kubernetes_queue == conf.get('local_kubernetes_executor', 'kubernetes_queue') + + def test_send_callback(self): + local_executor_mock = mock.MagicMock() + k8s_executor_mock = mock.MagicMock() + local_k8s_exec = LocalKubernetesExecutor(local_executor_mock, k8s_executor_mock) + local_k8s_exec.callback_sink = mock.MagicMock() + + callback = CallbackRequest(full_filepath="fake") + local_k8s_exec.send_callback(callback) + + local_k8s_exec.callback_sink.send.assert_called_once_with(callback)