Skip to content

Commit

Permalink
add missing read for K8S config file from conn in deferred `Kubernete…
Browse files Browse the repository at this point in the history
…sPodOperator` (#29498)


* restore convert_config_file_to_dict method and deprecate it
  • Loading branch information
hussein-awala authored Apr 22, 2023
1 parent 155ef09 commit b5296b7
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 24 deletions.
18 changes: 11 additions & 7 deletions airflow/providers/cncf/kubernetes/hooks/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,19 +467,18 @@ def _get_bool(val) -> bool | None:
class AsyncKubernetesHook(KubernetesHook):
"""Hook to use Kubernetes SDK asynchronously."""

def __init__(self, config_dict: dict | None = None, *args, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config_dict = config_dict

self._extras: dict | None = None

async def _load_config(self):
"""Returns Kubernetes API session for use with requests"""
in_cluster = self._coalesce_param(self.in_cluster, await self._get_field("in_cluster"))
cluster_context = self._coalesce_param(self.cluster_context, await self._get_field("cluster_context"))
kubeconfig_path = self._coalesce_param(self.config_file, await self._get_field("kube_config_path"))
kubeconfig = await self._get_field("kube_config")

num_selected_configuration = len([o for o in [in_cluster, kubeconfig, self.config_dict] if o])
num_selected_configuration = len([o for o in [in_cluster, kubeconfig, kubeconfig_path] if o])

if num_selected_configuration > 1:
raise AirflowException(
Expand All @@ -494,9 +493,14 @@ async def _load_config(self):
async_config.load_incluster_config()
return async_client.ApiClient()

if self.config_dict:
self.log.debug(LOADING_KUBE_CONFIG_FILE_RESOURCE.format("config dictionary"))
await async_config.load_kube_config_from_dict(self.config_dict)
if kubeconfig_path:
self.log.debug(LOADING_KUBE_CONFIG_FILE_RESOURCE.format("kube_config"))
self._is_in_cluster = False
await async_config.load_kube_config(
config_file=kubeconfig_path,
client_configuration=self.client_configuration,
context=cluster_context,
)
return async_client.ApiClient()

if kubeconfig is not None:
Expand Down
11 changes: 7 additions & 4 deletions airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,7 @@ def __init__(
self.deferrable = deferrable
self.poll_interval = poll_interval
self.remote_pod: k8s.V1Pod | None = None

self._config_dict: dict | None = None
self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict

@cached_property
def _incluster_namespace(self):
Expand Down Expand Up @@ -572,11 +571,15 @@ def execute_async(self, context: Context):
pod_request_obj=self.pod_request_obj,
context=context,
)
self.convert_config_file_to_dict()
self.invoke_defer_method()

def convert_config_file_to_dict(self):
"""Converts passed config_file to dict format."""
warnings.warn(
"This method is deprecated and will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)
config_file = self.config_file if self.config_file else os.environ.get(KUBE_CONFIG_ENV_VAR)
if config_file:
with open(config_file) as f:
Expand All @@ -594,7 +597,7 @@ def invoke_defer_method(self):
trigger_start_time=trigger_start_time,
kubernetes_conn_id=self.kubernetes_conn_id,
cluster_context=self.cluster_context,
config_dict=self._config_dict,
config_file=self.config_file,
in_cluster=self.in_cluster,
poll_interval=self.poll_interval,
should_delete_pod=self.is_delete_operator_pod,
Expand Down
11 changes: 5 additions & 6 deletions airflow/providers/cncf/kubernetes/triggers/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ class KubernetesPodTrigger(BaseTrigger):
:param kubernetes_conn_id: The :ref:`kubernetes connection id <howto/connection:kubernetes>`
for the Kubernetes cluster.
:param cluster_context: Context that points to kubernetes cluster.
:param config_dict: Kubernetes config file content in dict format. If not specified,
default value is ``~/.kube/config``
:param config_file: Path to kubeconfig file.
:param poll_interval: Polling period in seconds to check for the status.
:param trigger_start_time: time in Datetime format when the trigger was started
:param in_cluster: run kubernetes client with in_cluster configuration.
Expand All @@ -73,7 +72,7 @@ def __init__(
kubernetes_conn_id: str | None = None,
poll_interval: float = 2,
cluster_context: str | None = None,
config_dict: dict | None = None,
config_file: str | None = None,
in_cluster: bool | None = None,
should_delete_pod: bool = True,
get_logs: bool = True,
Expand All @@ -87,7 +86,7 @@ def __init__(
self.kubernetes_conn_id = kubernetes_conn_id
self.poll_interval = poll_interval
self.cluster_context = cluster_context
self.config_dict = config_dict
self.config_file = config_file
self.in_cluster = in_cluster
self.should_delete_pod = should_delete_pod
self.get_logs = get_logs
Expand All @@ -107,7 +106,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"kubernetes_conn_id": self.kubernetes_conn_id,
"poll_interval": self.poll_interval,
"cluster_context": self.cluster_context,
"config_dict": self.config_dict,
"config_file": self.config_file,
"in_cluster": self.in_cluster,
"should_delete_pod": self.should_delete_pod,
"get_logs": self.get_logs,
Expand Down Expand Up @@ -215,7 +214,7 @@ def _get_async_hook(self) -> AsyncKubernetesHook:
self._hook = AsyncKubernetesHook(
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
config_dict=self.config_dict,
config_file=self.config_file,
cluster_context=self.cluster_context,
)
return self._hook
Expand Down
3 changes: 1 addition & 2 deletions tests/providers/cncf/kubernetes/operators/test_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,10 +1230,9 @@ def run_pod_async(self, operator: KubernetesPodOperator, map_index: int = -1):
)
return remote_pod_mock

@patch(KUB_OP_PATH.format("convert_config_file_to_dict"))
@patch(KUB_OP_PATH.format("build_pod_request_obj"))
@patch(KUB_OP_PATH.format("get_or_create_pod"))
def test_async_create_pod_should_execute_successfully(self, mocked_pod, mocked_pod_obj, mocked_conf_file):
def test_async_create_pod_should_execute_successfully(self, mocked_pod, mocked_pod_obj):
"""
Asserts that a task is deferred and the KubernetesCreatePodTrigger will be fired
when the KubernetesPodOperator is executed in deferrable mode when deferrable=True.
Expand Down
6 changes: 3 additions & 3 deletions tests/providers/cncf/kubernetes/triggers/test_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
CONN_ID = "test_kubernetes_conn_id"
POLL_INTERVAL = 2
CLUSTER_CONTEXT = "test-context"
CONFIG_DICT = {"a": "b"}
CONFIG_FILE = "/path/to/config/file"
IN_CLUSTER = False
SHOULD_DELETE_POD = True
GET_LOGS = True
Expand All @@ -61,7 +61,7 @@ def trigger():
kubernetes_conn_id=CONN_ID,
poll_interval=POLL_INTERVAL,
cluster_context=CLUSTER_CONTEXT,
config_dict=CONFIG_DICT,
config_file=CONFIG_FILE,
in_cluster=IN_CLUSTER,
should_delete_pod=SHOULD_DELETE_POD,
get_logs=GET_LOGS,
Expand All @@ -88,7 +88,7 @@ def test_serialize(self, trigger):
"kubernetes_conn_id": CONN_ID,
"poll_interval": POLL_INTERVAL,
"cluster_context": CLUSTER_CONTEXT,
"config_dict": CONFIG_DICT,
"config_file": CONFIG_FILE,
"in_cluster": IN_CLUSTER,
"should_delete_pod": SHOULD_DELETE_POD,
"get_logs": GET_LOGS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ def setup_method(self):
self.gke_op._cluster_url = CLUSTER_URL
self.gke_op._ssl_ca_cert = SSL_CA_CERT

@mock.patch(KUB_OP_PATH.format("convert_config_file_to_dict"))
@mock.patch.dict(os.environ, {})
@mock.patch(KUB_OP_PATH.format("build_pod_request_obj"))
@mock.patch(KUB_OP_PATH.format("get_or_create_pod"))
Expand All @@ -323,7 +322,7 @@ def setup_method(self):
)
@mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info")
def test_async_create_pod_should_execute_successfully(
self, fetch_cluster_info_mock, get_con_mock, mocked_pod, mocked_pod_obj, mocked_config
self, fetch_cluster_info_mock, get_con_mock, mocked_pod, mocked_pod_obj
):
"""
Asserts that a task is deferred and the GKEStartPodTrigger will be fired
Expand Down

0 comments on commit b5296b7

Please sign in to comment.