From 164526d4c798a72dba3087d71f30f60f60595b0e Mon Sep 17 00:00:00 2001 From: Changhoon Oh <81631424+okayhooni@users.noreply.github.com> Date: Sat, 5 Aug 2023 03:41:21 +0900 Subject: [PATCH] Consider custom pod labels on pod finding process on `KubernetesPodOperator` (#33057) * consider custom pod labels on pod finding process on KubernetesPodOperator --------- Co-authored-by: eladkal <45845474+eladkal@users.noreply.github.com> --- airflow/providers/cncf/kubernetes/operators/pod.py | 5 ++++- .../providers/cncf/kubernetes/operators/test_pod.py | 12 +++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py index 0b6c90c987dd..d72b2fec0c28 100644 --- a/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/airflow/providers/cncf/kubernetes/operators/pod.py @@ -762,7 +762,10 @@ def process_pod_deletion(self, pod: k8s.V1Pod, *, reraise=True): self.log.info("Skipping deleting pod: %s", pod.metadata.name) def _build_find_pod_label_selector(self, context: Context | None = None, *, exclude_checked=True) -> str: - labels = self._get_ti_pod_labels(context, include_try_number=False) + labels = { + **self.labels, + **self._get_ti_pod_labels(context, include_try_number=False), + } label_strings = [f"{label_id}={label}" for label_id, label in sorted(labels.items())] labels_value = ",".join(label_strings) if exclude_checked: diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py b/tests/providers/cncf/kubernetes/operators/test_pod.py index bcd2bd0a31f2..2c6dc2188a7d 100644 --- a/tests/providers/cncf/kubernetes/operators/test_pod.py +++ b/tests/providers/cncf/kubernetes/operators/test_pod.py @@ -312,6 +312,16 @@ def test_labels_mapped(self): "airflow_kpo_in_cluster": str(k.hook.is_in_cluster), } + def test_find_custom_pod_labels(self): + k = KubernetesPodOperator( + labels={"foo": "bar", "hello": "airflow"}, + name="test", + task_id="task", + ) + context = create_context(k) + label_selector = k._build_find_pod_label_selector(context) + assert "foo=bar" in label_selector and "hello=airflow" in label_selector + @patch(HOOK_CLASS, new=MagicMock) def test_find_pod_labels(self): k = KubernetesPodOperator( @@ -327,7 +337,7 @@ def test_find_pod_labels(self): self.run_pod(k) _, kwargs = k.client.list_namespaced_pod.call_args assert kwargs["label_selector"] == ( - "dag_id=dag,kubernetes_pod_operator=True,run_id=test,task_id=task," + "dag_id=dag,foo=bar,kubernetes_pod_operator=True,run_id=test,task_id=task," "already_checked!=True,!airflow-worker" )