diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index 91656879468c7..576ff80e58de2 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -61,15 +61,45 @@ class ResourceVersion: - """Singleton for tracking resourceVersion from Kubernetes""" + """ + Track resourceVersion from Kubernetes + + All instances of this class share the same state + """ + + _shared_state = {} + + def __init__( + self, + *, + kube_client: client.CoreV1Api = None, + namespace: str = None, + resource_version: Optional[str] = None, + ): + self.__dict__ = self._shared_state + if resource_version: + # Update the state + self.resource_version = resource_version + if not hasattr(self, 'resource_version'): + if not (kube_client and namespace): + raise AirflowException("kube_client and namespace is required to get resource version") + re_version = get_latest_resource_version(kube_client, namespace) + self._shared_state.update(resource_version=re_version) - _instance = None - resource_version = "0" + @classmethod + def _drop(cls): + """Clear shared state (For testing purposes)""" + cls._shared_state = {} - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance + +def get_latest_resource_version(kube_client: client.CoreV1Api, namespace: str) -> None: + """ + List pods to get the latest resource version + + See https://kubernetes.io/docs/reference/using-api/api-concepts/#efficient-detection-of-changes + """ + pod_list = kube_client.list_namespaced_pod(namespace) + return pod_list.metadata.resource_version class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): @@ -80,7 +110,7 @@ def __init__( namespace: Optional[str], multi_namespace_mode: bool, watcher_queue: 'Queue[KubernetesWatchType]', - resource_version: Optional[str], + resource_version: str, scheduler_job_id: Optional[str], kube_config: Configuration, ): @@ -102,6 +132,22 @@ def run(self) -> None: self.resource_version = self._run( kube_client, self.resource_version, self.scheduler_job_id, self.kube_config ) + except ApiException as err: + if err.status == 410: + self.log.info( + "KubernetesJobWatcher encountered an error, error code: %s, reason: %s", + err.status, + err.reason, + ) + self.log.info("Relisting pod to get the latest resource version") + self.resource_version = get_latest_resource_version(kube_client, self.namespace) + else: + self.log.exception( + 'KubernetesJobWatcher encountered an error, failing, error code: %s, reason: %s', + err.status, + err.reason, + ) + raise except ReadTimeoutError: self.log.warning( "There was a timeout error accessing the Kube API. Retrying request.", exc_info=True @@ -119,21 +165,22 @@ def run(self) -> None: def _run( self, kube_client: client.CoreV1Api, - resource_version: Optional[str], + resource_version: str, scheduler_job_id: str, kube_config: Any, ) -> Optional[str]: self.log.info('Event: and now my watch begins starting at resource_version: %s', resource_version) watcher = watch.Watch() - kwargs = {'label_selector': f'airflow-worker={scheduler_job_id}'} - if resource_version: - kwargs['resource_version'] = resource_version + kwargs = { + 'label_selector': f'airflow-worker={scheduler_job_id}', + 'resource_version': resource_version, + } if kube_config.kube_client_request_args: for key, value in kube_config.kube_client_request_args.items(): kwargs[key] = value - last_resource_version: Optional[str] = None + last_resource_version: str = resource_version if self.multi_namespace_mode: list_worker_pods = functools.partial( watcher.stream, kube_client.list_pod_for_all_namespaces, **kwargs @@ -146,7 +193,7 @@ def _run( task = event['object'] self.log.info('Event: %s had an event of type %s', task.metadata.name, event['type']) if event['type'] == 'ERROR': - return self.process_error(event) + return self.process_error(event, kube_client) annotations = task.metadata.annotations task_instance_related_annotations = { 'dag_id': annotations['dag_id'], @@ -154,7 +201,6 @@ def _run( 'execution_date': annotations['execution_date'], 'try_number': annotations['try_number'], } - self.process_status( pod_id=task.metadata.name, namespace=task.metadata.namespace, @@ -167,16 +213,21 @@ def _run( return last_resource_version - def process_error(self, event: Any) -> str: + def process_error( + self, + event: Any, + kube_client: client.CoreV1Api, + ) -> str: """Process error response""" self.log.error('Encountered Error response from k8s list namespaced pod stream => %s', event) raw_object = event['raw_object'] if raw_object['code'] == 410: self.log.info( - 'Kubernetes resource version is too old, must reset to 0 => %s', (raw_object['message'],) + 'Kubernetes resource version is too old, ' + 'relisting pods to get the latest version. Error => %s', + (raw_object['message'],), ) - # Return resource version 0 - return '0' + return get_latest_resource_version(kube_client, self.namespace) raise AirflowException( 'Kubernetes failure for %s with code %s and message: %s' % (raw_object['reason'], raw_object['code'], raw_object['message']) @@ -261,7 +312,9 @@ def run_pod_async(self, pod: k8s.V1Pod, **kwargs): return resp def _make_kube_watcher(self) -> KubernetesJobWatcher: - resource_version = ResourceVersion().resource_version + resource_version = ResourceVersion( + kube_client=self.kube_client, namespace=self.kube_config.kube_namespace + ).resource_version watcher = KubernetesJobWatcher( watcher_queue=self.watcher_queue, namespace=self.kube_config.kube_namespace, @@ -535,7 +588,6 @@ def sync(self) -> None: if not self.task_queue: raise AirflowException(NOT_STARTED_MESSAGE) self.kube_scheduler.sync() - last_resource_version = None while True: try: @@ -558,9 +610,8 @@ def sync(self) -> None: self.result_queue.task_done() except Empty: break - - resource_instance = ResourceVersion() - resource_instance.resource_version = last_resource_version or resource_instance.resource_version + if last_resource_version: + ResourceVersion(resource_version=last_resource_version) for _ in range(self.kube_config.worker_pods_creation_batch_size): try: diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index e48d6ce9e3a81..7a95861bac618 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -37,8 +37,10 @@ AirflowKubernetesScheduler, KubernetesExecutor, KubernetesJobWatcher, + ResourceVersion, create_pod_id, get_base_pod_from_template, + get_latest_resource_version, ) from airflow.kubernetes import pod_generator from airflow.kubernetes.kubernetes_helper_functions import annotations_to_key @@ -756,18 +758,18 @@ def test_process_status_catchall(self): self._run() self.watcher.watcher_queue.put.assert_not_called() - @mock.patch.object(KubernetesJobWatcher, 'process_error') - def test_process_error_event_for_410(self, mock_process_error): + @mock.patch('airflow.executors.kubernetes_executor.get_latest_resource_version') + def test_process_error_event_for_410(self, mock_get_resource_version): + mock_get_resource_version.return_value = '43334' message = "too old resource version: 27272 (43334)" self.pod.status.phase = 'Pending' - self.pod.metadata.resource_version = '0' - mock_process_error.return_value = '0' + self.pod.metadata.resource_version = '43334' raw_object = {"code": 410, "message": message} self.events.append({"type": "ERROR", "object": self.pod, "raw_object": raw_object}) self._run() - mock_process_error.assert_called_once_with(self.events[0]) + mock_get_resource_version.assert_called_once() - def test_process_error_event_for_raise_if_not_410(self): + def test_process_error_event_raise_if_not_410(self): message = "Failure message" self.pod.status.phase = 'Pending' raw_object = {"code": 422, "message": message, "reason": "Test"} @@ -779,3 +781,53 @@ def test_process_error_event_for_raise_if_not_410(self): raw_object['code'], raw_object['message'], ) + + @mock.patch('airflow.executors.kubernetes_executor.get_latest_resource_version') + @mock.patch.object(KubernetesJobWatcher, '_run') + def test_apiexception_for_410_is_handled(self, mock_run, mock_get_resource_version): + self.events.append({"type": 'MODIFIED', "object": self.pod}) + mock_run.side_effect = mock.Mock(side_effect=ApiException(status=410, reason='too old error')) + with self.assertRaises(ApiException): + self.watcher._run( + kube_client=self.kube_client, + resource_version=self.watcher.resource_version, + scheduler_job_id=self.watcher.scheduler_job_id, + kube_config=self.watcher.kube_config, + ) + mock_get_resource_version.assert_called_once() + + +class TestResourceVersion(unittest.TestCase): + # pylint: disable=no-member + def tearDown(self) -> None: + ResourceVersion._drop() + + def test_can_update_with_resource_version_arg(self): + resource_instance = ResourceVersion(resource_version='4567') + assert resource_instance.resource_version == '4567' + + @mock.patch('airflow.executors.kubernetes_executor.get_latest_resource_version') + def test_different_instance_share_state(self, mock_get_resource_version): + kube_client = mock.MagicMock() + mock_get_resource_version.return_value = '4566' + resource_instance = ResourceVersion(kube_client=kube_client, namespace='mynamespace') + resource_instance2 = ResourceVersion(kube_client=kube_client, namespace='mynamespace') + assert resource_instance.resource_version == '4566' + assert resource_instance2.resource_version == '4566' + resource_instance3 = ResourceVersion(resource_version='6787') + resource_instance4 = ResourceVersion(kube_client=kube_client, namespace='mynamespace') + assert resource_instance.resource_version == '6787' + assert resource_instance2.resource_version == '6787' + assert resource_instance3.resource_version == '6787' + assert resource_instance4.resource_version == '6787' + mock_get_resource_version.assert_called_once() + + +class TestGetLatestResourceVersion(unittest.TestCase): + def test_get_latest_resource_version(self): + kube_client = mock.MagicMock() + list_namespaced_pod = kube_client.list_namespaced_pod + list_namespaced_pod.return_value.metadata.resource_version = '5688' + resource_version = get_latest_resource_version(kube_client, 'mynamespace') + assert list_namespaced_pod.called + assert resource_version == '5688'