From 900ad8c1907d3342ba1777ad99db37a0d3f5d61a Mon Sep 17 00:00:00 2001
From: pegasas <616672335@qq.com>
Date: Sat, 5 Aug 2023 02:16:25 +0800
Subject: [PATCH] Fix: Configurable Docker image of `xcom_sidecar` (#32858)

* Configurable Docker image of xcom_sidecar

* Update airflow/providers/cncf/kubernetes/utils/pod_manager.py

* Update airflow/providers/cncf/kubernetes/utils/pod_manager.py

* Update kubernetes.py

---------

Co-authored-by: eladkal <45845474+eladkal@users.noreply.github.com>
---
 .../cncf/kubernetes/hooks/kubernetes.py      | 18 ++++++++
 .../cncf/kubernetes/operators/pod.py         |  6 ++-
 .../cncf/kubernetes/utils/pod_manager.py     |  6 +++
 .../test_kubernetes_pod_operator.py          |  2 +
 .../kubernetes/decorators/test_kubernetes.py | 11 ++++-
 .../cncf/kubernetes/hooks/test_kubernetes.py | 46 +++++++++++++++++++
 6 files changed, 87 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index 56852fb1a200..ddb8cb27ad52 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import contextlib
+import json
 import tempfile
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Generator
@@ -99,6 +100,12 @@ def get_connection_form_widgets() -> dict[str, Any]:
             "cluster_context": StringField(lazy_gettext("Cluster context"), widget=BS3TextFieldWidget()),
             "disable_verify_ssl": BooleanField(lazy_gettext("Disable SSL")),
             "disable_tcp_keepalive": BooleanField(lazy_gettext("Disable TCP keepalive")),
+            "xcom_sidecar_container_image": StringField(
+                lazy_gettext("XCom sidecar image"), widget=BS3TextFieldWidget()
+            ),
+            "xcom_sidecar_container_resources": StringField(
+                lazy_gettext("XCom sidecar resources (JSON format)"), widget=BS3TextFieldWidget()
+            ),
         }
 
     @staticmethod
@@ -356,6 +363,17 @@ def get_namespace(self) -> str | None:
             return self._get_field("namespace")
         return None
 
+    def get_xcom_sidecar_container_image(self):
+        """Returns the xcom sidecar image that defined in the connection."""
+        return self._get_field("xcom_sidecar_container_image")
+
+    def get_xcom_sidecar_container_resources(self):
+        """Returns the xcom sidecar resources that defined in the connection."""
+        field = self._get_field("xcom_sidecar_container_resources")
+        if not field:
+            return None
+        return json.loads(field)
+
     def get_pod_log_stream(
         self,
         pod_name: str,
diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py
index 28810b92ffda..c707f9446f16 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -878,7 +878,11 @@ def build_pod_request_obj(self, context: Context | None = None) -> k8s.V1Pod:
             pod = secret.attach_to_pod(pod)
         if self.do_xcom_push:
             self.log.debug("Adding xcom sidecar to task %s", self.task_id)
-            pod = xcom_sidecar.add_xcom_sidecar(pod)
+            pod = xcom_sidecar.add_xcom_sidecar(
+                pod,
+                sidecar_container_image=self.hook.get_xcom_sidecar_container_image(),
+                sidecar_container_resources=self.hook.get_xcom_sidecar_container_resources(),
+            )
 
         labels = self._get_ti_pod_labels(context)
         self.log.info("Building pod %s with labels: %s", pod.metadata.name, labels)
diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index c8ac74382d6d..81b6c1b2ca3f 100644
--- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -101,6 +101,12 @@ def get_pod(self, name: str, namespace: str) -> V1Pod:
     def get_namespace(self) -> str | None:
         """Returns the namespace that defined in the connection."""
 
+    def get_xcom_sidecar_container_image(self) -> str | None:
+        """Returns the xcom sidecar image that defined in the connection."""
+
+    def get_xcom_sidecar_container_resources(self) -> str | None:
+        """Returns the xcom sidecar resources that defined in the connection."""
+
 
 def get_container_status(pod: V1Pod, container_name: str) -> V1ContainerStatus | None:
     """Retrieves container status."""
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index 002394611d9a..7ba097f18dc0 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -897,6 +897,8 @@ def test_pod_template_file(
         # todo: This isn't really a system test
         await_xcom_sidecar_container_start_mock.return_value = None
         hook_mock.return_value.is_in_cluster = False
+        hook_mock.return_value.get_xcom_sidecar_container_image.return_value = None
+        hook_mock.return_value.get_xcom_sidecar_container_resources.return_value = None
         hook_mock.return_value.get_connection.return_value = Connection(conn_id="kubernetes_default")
         extract_xcom_mock.return_value = "{}"
         path = sys.path[0] + "/tests/providers/cncf/kubernetes/pod.yaml"
diff --git a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
index 9bd7c06e4002..b3ac936fdae6 100644
--- a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
@@ -30,6 +30,7 @@
 KPO_MODULE = "airflow.providers.cncf.kubernetes.operators.pod"
 POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
 HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.pod.KubernetesHook"
+XCOM_IMAGE = "XCOM_IMAGE"
 
 
 @pytest.fixture(autouse=True)
@@ -122,6 +123,12 @@ def f(arg1, arg2, kwarg1=None, kwarg2=None):
 
         f.override(task_id="my_task_id", do_xcom_push=True)("arg1", "arg2", kwarg1="kwarg1")
 
+    mock_hook.return_value.get_xcom_sidecar_container_image.return_value = XCOM_IMAGE
+    mock_hook.return_value.get_xcom_sidecar_container_resources.return_value = {
+        "requests": {"cpu": "1m", "memory": "10Mi"},
+        "limits": {"cpu": "1m", "memory": "50Mi"},
+    }
+
     dr = dag_maker.create_dagrun()
     (ti,) = dr.task_instances
 
@@ -134,6 +141,8 @@ def f(arg1, arg2, kwarg1=None, kwarg2=None):
         config_file="/tmp/fake_file",
     )
     assert mock_create_pod.call_count == 1
+    assert mock_hook.return_value.get_xcom_sidecar_container_image.call_count == 1
+    assert mock_hook.return_value.get_xcom_sidecar_container_resources.call_count == 1
 
     containers = mock_create_pod.call_args.kwargs["pod"].spec.containers
 
@@ -152,7 +161,7 @@ def f(arg1, arg2, kwarg1=None, kwarg2=None):
     assert decoded_input == {"args": ("arg1", "arg2"), "kwargs": {"kwarg1": "kwarg1"}}
 
     # Second container is xcom image
-    assert containers[1].image == "alpine"
+    assert containers[1].image == XCOM_IMAGE
     assert containers[1].volume_mounts[0].mount_path == "/airflow/xcom"
 
 
diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
index ba151efaf7fa..7b5428481eb7 100644
--- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
@@ -88,6 +88,20 @@ def setup_class(cls) -> None:
             ("disable_verify_ssl_empty", {"disable_verify_ssl": ""}),
             ("disable_tcp_keepalive", {"disable_tcp_keepalive": True}),
             ("disable_tcp_keepalive_empty", {"disable_tcp_keepalive": ""}),
+            ("sidecar_container_image", {"xcom_sidecar_container_image": "private.repo.com/alpine:3.16"}),
+            ("sidecar_container_image_empty", {"xcom_sidecar_container_image": ""}),
+            (
+                "sidecar_container_resources",
+                {
+                    "xcom_sidecar_container_resources": json.dumps(
+                        {
+                            "requests": {"cpu": "1m", "memory": "10Mi"},
+                            "limits": {"cpu": "1m", "memory": "50Mi"},
+                        }
+                    ),
+                },
+            ),
+            ("sidecar_container_resources_empty", {"xcom_sidecar_container_resources": ""}),
         ]:
             db.merge_conn(Connection(conn_type="kubernetes", conn_id=conn_id, extra=json.dumps(extra)))
 
@@ -342,6 +356,38 @@ def test_get_namespace(self, conn_id, expected):
                 "and rename _get_namespace to get_namespace."
             )
 
+    @pytest.mark.parametrize(
+        "conn_id, expected",
+        (
+            pytest.param("sidecar_container_image", "private.repo.com/alpine:3.16", id="sidecar-with-image"),
+            pytest.param("sidecar_container_image_empty", None, id="sidecar-without-image"),
+        ),
+    )
+    def test_get_xcom_sidecar_container_image(self, conn_id, expected):
+        hook = KubernetesHook(conn_id=conn_id)
+        assert hook.get_xcom_sidecar_container_image() == expected
+
+    @pytest.mark.parametrize(
+        "conn_id, expected",
+        (
+            pytest.param(
+                "sidecar_container_resources",
+                {
+                    "requests": {"cpu": "1m", "memory": "10Mi"},
+                    "limits": {
+                        "cpu": "1m",
+                        "memory": "50Mi",
+                    },
+                },
+                id="sidecar-with-resources",
+            ),
+            pytest.param("sidecar_container_resources_empty", None, id="sidecar-without-resources"),
+        ),
+    )
+    def test_get_xcom_sidecar_container_resources(self, conn_id, expected):
+        hook = KubernetesHook(conn_id=conn_id)
+        assert hook.get_xcom_sidecar_container_resources() == expected
+
     @patch("kubernetes.config.kube_config.KubeConfigLoader")
     @patch("kubernetes.config.kube_config.KubeConfigMerger")
     def test_client_types(self, mock_kube_config_merger, mock_kube_config_loader):