diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py index 9ee718165ab34..83c013ba441fc 100644 --- a/airflow/providers/google/cloud/operators/kubernetes_engine.py +++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py @@ -21,7 +21,8 @@ import os import tempfile import warnings -from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union +from contextlib import contextmanager +from typing import TYPE_CHECKING, Dict, Generator, Optional, Sequence, Union from google.cloud.container_v1.types import Cluster @@ -336,11 +337,22 @@ def __init__( if self.config_file: raise AirflowException("config_file is not an allowed parameter for the GKEStartPodOperator.") - def execute(self, context: 'Context') -> Optional[str]: - hook = GoogleBaseHook(gcp_conn_id=self.gcp_conn_id) - self.project_id = self.project_id or hook.project_id + @staticmethod + @contextmanager + def get_gke_config_file( + gcp_conn_id, + project_id: Optional[str], + cluster_name: str, + impersonation_chain: Optional[Union[str, Sequence[str]]], + regional: bool, + location: str, + use_internal_ip: bool, + ) -> Generator[str, None, None]: - if not self.project_id: + hook = GoogleBaseHook(gcp_conn_id=gcp_conn_id) + project_id = project_id or hook.project_id + + if not project_id: raise AirflowException( "The project id must be passed either as " "keyword project_id parameter or as project_id extra " @@ -363,15 +375,15 @@ def execute(self, context: 'Context') -> Optional[str]: "container", "clusters", "get-credentials", - self.cluster_name, + cluster_name, "--project", - self.project_id, + project_id, ] - if self.impersonation_chain: - if isinstance(self.impersonation_chain, str): - impersonation_account = self.impersonation_chain - elif len(self.impersonation_chain) == 1: - impersonation_account = self.impersonation_chain[0] + if impersonation_chain: + if isinstance(impersonation_chain, str): + impersonation_account = impersonation_chain + elif len(impersonation_chain) == 1: + impersonation_account = impersonation_chain[0] else: raise AirflowException( "Chained list of accounts is not supported, please specify only one service account" @@ -383,15 +395,28 @@ def execute(self, context: 'Context') -> Optional[str]: impersonation_account, ] ) - if self.regional: + if regional: cmd.append('--region') else: cmd.append('--zone') - cmd.append(self.location) - if self.use_internal_ip: + cmd.append(location) + if use_internal_ip: cmd.append('--internal-ip') execute_in_subprocess(cmd) # Tell `KubernetesPodOperator` where the config file is located - self.config_file = os.environ[KUBE_CONFIG_ENV_VAR] + yield os.environ[KUBE_CONFIG_ENV_VAR] + + def execute(self, context: 'Context') -> Optional[str]: + + with GKEStartPodOperator.get_gke_config_file( + gcp_conn_id=self.gcp_conn_id, + project_id=self.project_id, + cluster_name=self.cluster_name, + impersonation_chain=self.impersonation_chain, + regional=self.regional, + location=self.location, + use_internal_ip=self.use_internal_ip, + ) as config_file: + self.config_file = config_file return super().execute(context)