From 743bf5a0ae1279c96d018aad54dcce108f16dc96 Mon Sep 17 00:00:00 2001 From: Syed Hussaain <103602455+syedahsn@users.noreply.github.com> Date: Tue, 20 Jun 2023 14:20:39 -0700 Subject: [PATCH] Add custom waiters to EMR Serverless (#30463) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Move waiter logic to utils folder --------- Co-authored-by: Raphaël Vandon --- airflow/providers/amazon/aws/operators/emr.py | 279 ++++++++----- .../amazon/aws/utils/waiter_with_logging.py | 90 +++++ .../amazon/aws/waiters/emr-serverless.json | 139 +++++++ .../aws/operators/test_emr_serverless.py | 369 +++++++++++++----- .../aws/utils/test_waiter_with_logging.py | 304 +++++++++++++++ 5 files changed, 976 insertions(+), 205 deletions(-) create mode 100644 airflow/providers/amazon/aws/utils/waiter_with_logging.py create mode 100644 tests/providers/amazon/aws/utils/test_waiter_with_logging.py diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index b8ca53226e807..9fdad3b918f47 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -34,6 +34,7 @@ EmrTerminateJobFlowTrigger, ) from airflow.providers.amazon.aws.utils.waiter import waiter +from airflow.providers.amazon.aws.utils.waiter_with_logging import wait from airflow.utils.helpers import exactly_one, prune_dict from airflow.utils.types import NOTSET, ArgNotSet @@ -945,10 +946,13 @@ class EmrServerlessCreateApplicationOperator(BaseOperator): Its value must be unique for each request. :param config: Optional dictionary for arbitrary parameters to the boto API create_application call. :param aws_conn_id: AWS connection to use - :param waiter_countdown: Total amount of time, in seconds, the operator will wait for + :param waiter_countdown: (deprecated) Total amount of time, in seconds, the operator will wait for the application to start. Defaults to 25 minutes. 
- :param waiter_check_interval_seconds: Number of seconds between polling the state of the application. - Defaults to 60 seconds. + :param waiter_check_interval_seconds: (deprecated) Number of seconds between polling the state + of the application. Defaults to 60 seconds. + :param waiter_max_attempts: Number of times the waiter should poll the application to check the state. + If not set, the waiter will use its default value. + :param waiter_delay: Number of seconds between polling the state of the application. """ def __init__( @@ -959,18 +963,41 @@ def __init__( config: dict | None = None, wait_for_completion: bool = True, aws_conn_id: str = "aws_default", - waiter_countdown: int = 25 * 60, - waiter_check_interval_seconds: int = 60, + waiter_countdown: int | ArgNotSet = NOTSET, + waiter_check_interval_seconds: int | ArgNotSet = NOTSET, + waiter_max_attempts: int | ArgNotSet = NOTSET, + waiter_delay: int | ArgNotSet = NOTSET, **kwargs, ): + if waiter_check_interval_seconds is NOTSET: + waiter_delay = 60 if waiter_delay is NOTSET else waiter_delay + else: + waiter_delay = waiter_check_interval_seconds if waiter_delay is NOTSET else waiter_delay + warnings.warn( + "The parameter waiter_check_interval_seconds has been deprecated to standardize " + "naming conventions. Please use waiter_delay instead. In the " + "future this will default to None and defer to the waiter's default value." + ) + if waiter_countdown is NOTSET: + waiter_max_attempts = 25 if waiter_max_attempts is NOTSET else waiter_max_attempts + else: + if waiter_max_attempts is NOTSET: + # ignoring mypy because it doesn't like ArgNotSet as an operand, but neither variables + # are of type ArgNotSet at this point. + waiter_max_attempts = waiter_countdown // waiter_delay # type: ignore[operator] + warnings.warn( + "The parameter waiter_countdown has been deprecated to standardize " + "naming conventions. Please use waiter_max_attempts instead. 
In the " + "future this will default to None and defer to the waiter's default value." + ) self.aws_conn_id = aws_conn_id self.release_label = release_label self.job_type = job_type self.wait_for_completion = wait_for_completion self.kwargs = kwargs self.config = config or {} - self.waiter_countdown = waiter_countdown - self.waiter_check_interval_seconds = waiter_check_interval_seconds + self.waiter_max_attempts = int(waiter_max_attempts) # type: ignore[arg-type] + self.waiter_delay = int(waiter_delay) # type: ignore[arg-type] super().__init__(**kwargs) self.client_request_token = client_request_token or str(uuid4()) @@ -993,37 +1020,31 @@ def execute(self, context: Context) -> str | None: raise AirflowException(f"Application Creation failed: {response}") self.log.info("EMR serverless application created: %s", application_id) + waiter = self.hook.get_waiter("serverless_app_created") - # This should be replaced with a boto waiter when available. - waiter( - get_state_callable=self.hook.conn.get_application, - get_state_args={"applicationId": application_id}, - parse_response=["application", "state"], - desired_state={"CREATED"}, - failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES, - object_type="application", - action="created", - countdown=self.waiter_countdown, - check_interval_seconds=self.waiter_check_interval_seconds, + wait( + waiter=waiter, + waiter_delay=self.waiter_delay, + max_attempts=self.waiter_max_attempts, + args={"applicationId": application_id}, + failure_message="Serverless Application creation failed", + status_message="Serverless Application status is", + status_args=["application.state", "application.stateDetails"], ) - self.log.info("Starting application %s", application_id) self.hook.conn.start_application(applicationId=application_id) if self.wait_for_completion: - # This should be replaced with a boto waiter when available. 
- waiter( - get_state_callable=self.hook.conn.get_application, - get_state_args={"applicationId": application_id}, - parse_response=["application", "state"], - desired_state={"STARTED"}, - failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES, - object_type="application", - action="started", - countdown=self.waiter_countdown, - check_interval_seconds=self.waiter_check_interval_seconds, + waiter = self.hook.get_waiter("serverless_app_started") + wait( + waiter=waiter, + max_attempts=self.waiter_max_attempts, + waiter_delay=self.waiter_delay, + args={"applicationId": application_id}, + failure_message="Serverless Application failed to start", + status_message="Serverless Application status is", + status_args=["application.state", "application.stateDetails"], ) - return application_id @@ -1047,10 +1068,13 @@ class EmrServerlessStartJobOperator(BaseOperator): when waiting for the application be to in the ``STARTED`` state. :param aws_conn_id: AWS connection to use. :param name: Name for the EMR Serverless job. If not provided, a default name will be assigned. - :param waiter_countdown: Total amount of time, in seconds, the operator will wait for + :param waiter_countdown: (deprecated) Total amount of time, in seconds, the operator will wait for the job finish. Defaults to 25 minutes. - :param waiter_check_interval_seconds: Number of seconds between polling the state of the job. + :param waiter_check_interval_seconds: (deprecated) Number of seconds between polling the state of the job. Defaults to 60 seconds. + :param waiter_max_attempts: Number of times the waiter should poll the application to check the state. + If not set, the waiter will use its default value. + :param waiter_delay: Number of seconds between polling the state of the job run. 
""" template_fields: Sequence[str] = ( @@ -1077,10 +1101,33 @@ def __init__( wait_for_completion: bool = True, aws_conn_id: str = "aws_default", name: str | None = None, - waiter_countdown: int = 25 * 60, - waiter_check_interval_seconds: int = 60, + waiter_countdown: int | ArgNotSet = NOTSET, + waiter_check_interval_seconds: int | ArgNotSet = NOTSET, + waiter_max_attempts: int | ArgNotSet = NOTSET, + waiter_delay: int | ArgNotSet = NOTSET, **kwargs, ): + if waiter_check_interval_seconds is NOTSET: + waiter_delay = 60 if waiter_delay is NOTSET else waiter_delay + else: + waiter_delay = waiter_check_interval_seconds if waiter_delay is NOTSET else waiter_delay + warnings.warn( + "The parameter waiter_check_interval_seconds has been deprecated to standardize " + "naming conventions. Please use waiter_delay instead. In the " + "future this will default to None and defer to the waiter's default value." + ) + if waiter_countdown is NOTSET: + waiter_max_attempts = 25 if waiter_max_attempts is NOTSET else waiter_max_attempts + else: + if waiter_max_attempts is NOTSET: + # ignoring mypy because it doesn't like ArgNotSet as an operand, but neither variables + # are of type ArgNotSet at this point. + waiter_max_attempts = waiter_countdown // waiter_delay # type: ignore[operator] + warnings.warn( + "The parameter waiter_countdown has been deprecated to standardize " + "naming conventions. Please use waiter_max_attempts instead. In the " + "future this will default to None and defer to the waiter's default value." 
+ ) self.aws_conn_id = aws_conn_id self.application_id = application_id self.execution_role_arn = execution_role_arn @@ -1089,8 +1136,8 @@ def __init__( self.wait_for_completion = wait_for_completion self.config = config or {} self.name = name or self.config.pop("name", f"emr_serverless_job_airflow_{uuid4()}") - self.waiter_countdown = waiter_countdown - self.waiter_check_interval_seconds = waiter_check_interval_seconds + self.waiter_max_attempts = int(waiter_max_attempts) # type: ignore[arg-type] + self.waiter_delay = int(waiter_delay) # type: ignore[arg-type] self.job_id: str | None = None super().__init__(**kwargs) @@ -1107,17 +1154,16 @@ def execute(self, context: Context) -> str | None: app_state = self.hook.conn.get_application(applicationId=self.application_id)["application"]["state"] if app_state not in EmrServerlessHook.APPLICATION_SUCCESS_STATES: self.hook.conn.start_application(applicationId=self.application_id) - - waiter( - get_state_callable=self.hook.conn.get_application, - get_state_args={"applicationId": self.application_id}, - parse_response=["application", "state"], - desired_state={"STARTED"}, - failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES, - object_type="application", - action="started", - countdown=self.waiter_countdown, - check_interval_seconds=self.waiter_check_interval_seconds, + waiter = self.hook.get_waiter("serverless_app_started") + + wait( + waiter=waiter, + max_attempts=self.waiter_max_attempts, + waiter_delay=self.waiter_delay, + args={"applicationId": self.application_id}, + failure_message="Serverless Application failed to start", + status_message="Serverless Application status is", + status_args=["application.state", "application.stateDetails"], ) response = self.hook.conn.start_job_run( @@ -1136,21 +1182,17 @@ def execute(self, context: Context) -> str | None: self.job_id = response["jobRunId"] self.log.info("EMR serverless job started: %s", self.job_id) if self.wait_for_completion: - # This should be replaced with 
a boto waiter when available. - waiter( - get_state_callable=self.hook.conn.get_job_run, - get_state_args={ - "applicationId": self.application_id, - "jobRunId": self.job_id, - }, - parse_response=["jobRun", "state"], - desired_state=EmrServerlessHook.JOB_SUCCESS_STATES, - failure_states=EmrServerlessHook.JOB_FAILURE_STATES, - object_type="job", - action="run", - countdown=self.waiter_countdown, - check_interval_seconds=self.waiter_check_interval_seconds, + waiter = self.hook.get_waiter("serverless_job_completed") + wait( + waiter=waiter, + max_attempts=self.waiter_max_attempts, + waiter_delay=self.waiter_delay, + args={"applicationId": self.application_id, "jobRunId": self.job_id}, + failure_message="Serverless Job failed", + status_message="Serverless Job status is", + status_args=["jobRun.state", "jobRun.stateDetails"], ) + return self.job_id def on_kill(self) -> None: @@ -1180,8 +1222,8 @@ def on_kill(self) -> None: failure_states=set(), object_type="job", action="cancelled", - countdown=self.waiter_countdown, - check_interval_seconds=self.waiter_check_interval_seconds, + countdown=self.waiter_delay * self.waiter_max_attempts, + check_interval_seconds=self.waiter_delay, ) @@ -1213,16 +1255,39 @@ def __init__( application_id: str, wait_for_completion: bool = True, aws_conn_id: str = "aws_default", - waiter_countdown: int = 5 * 60, - waiter_check_interval_seconds: int = 30, + waiter_countdown: int | ArgNotSet = NOTSET, + waiter_check_interval_seconds: int | ArgNotSet = NOTSET, + waiter_max_attempts: int | ArgNotSet = NOTSET, + waiter_delay: int | ArgNotSet = NOTSET, force_stop: bool = False, **kwargs, ): + if waiter_check_interval_seconds is NOTSET: + waiter_delay = 60 if waiter_delay is NOTSET else waiter_delay + else: + waiter_delay = waiter_check_interval_seconds if waiter_delay is NOTSET else waiter_delay + warnings.warn( + "The parameter waiter_check_interval_seconds has been deprecated to standardize " + "naming conventions. Please use waiter_delay instead. 
In the " + "future this will default to None and defer to the waiter's default value." + ) + if waiter_countdown is NOTSET: + waiter_max_attempts = 25 if waiter_max_attempts is NOTSET else waiter_max_attempts + else: + if waiter_max_attempts is NOTSET: + # ignoring mypy because it doesn't like ArgNotSet as an operand, but neither variables + # are of type ArgNotSet at this point. + waiter_max_attempts = waiter_countdown // waiter_delay # type: ignore[operator] + warnings.warn( + "The parameter waiter_countdown has been deprecated to standardize " + "naming conventions. Please use waiter_max_attempts instead. In the " + "future this will default to None and defer to the waiter's default value." + ) self.aws_conn_id = aws_conn_id self.application_id = application_id self.wait_for_completion = wait_for_completion - self.waiter_countdown = waiter_countdown - self.waiter_check_interval_seconds = waiter_check_interval_seconds + self.waiter_max_attempts = int(waiter_max_attempts) # type: ignore[arg-type] + self.waiter_delay = int(waiter_delay) # type: ignore[arg-type] self.force_stop = force_stop super().__init__(**kwargs) @@ -1238,27 +1303,23 @@ def execute(self, context: Context) -> None: self.hook.cancel_running_jobs( self.application_id, waiter_config={ - "Delay": self.waiter_check_interval_seconds, - "MaxAttempts": self.waiter_countdown / self.waiter_check_interval_seconds, + "Delay": self.waiter_delay, + "MaxAttempts": self.waiter_max_attempts, }, ) self.hook.conn.stop_application(applicationId=self.application_id) if self.wait_for_completion: - # This should be replaced with a boto waiter when available. 
- waiter( - get_state_callable=self.hook.conn.get_application, - get_state_args={ - "applicationId": self.application_id, - }, - parse_response=["application", "state"], - desired_state=EmrServerlessHook.APPLICATION_FAILURE_STATES, - failure_states=set(), - object_type="application", - action="stopped", - countdown=self.waiter_countdown, - check_interval_seconds=self.waiter_check_interval_seconds, + waiter = self.hook.get_waiter("serverless_app_stopped") + wait( + waiter=waiter, + max_attempts=self.waiter_max_attempts, + waiter_delay=self.waiter_delay, + args={"applicationId": self.application_id}, + failure_message="Error stopping application", + status_message="Serverless Application status is", + status_args=["application.state", "application.stateDetails"], ) self.log.info("EMR serverless application %s stopped successfully", self.application_id) @@ -1292,11 +1353,34 @@ def __init__( application_id: str, wait_for_completion: bool = True, aws_conn_id: str = "aws_default", - waiter_countdown: int = 25 * 60, - waiter_check_interval_seconds: int = 60, + waiter_countdown: int | ArgNotSet = NOTSET, + waiter_check_interval_seconds: int | ArgNotSet = NOTSET, + waiter_max_attempts: int | ArgNotSet = NOTSET, + waiter_delay: int | ArgNotSet = NOTSET, force_stop: bool = False, **kwargs, ): + if waiter_check_interval_seconds is NOTSET: + waiter_delay = 60 if waiter_delay is NOTSET else waiter_delay + else: + waiter_delay = waiter_check_interval_seconds if waiter_delay is NOTSET else waiter_delay + warnings.warn( + "The parameter waiter_check_interval_seconds has been deprecated to standardize " + "naming conventions. Please use waiter_delay instead. In the " + "future this will default to None and defer to the waiter's default value." 
+ ) + if waiter_countdown is NOTSET: + waiter_max_attempts = 25 if waiter_max_attempts is NOTSET else waiter_max_attempts + else: + if waiter_max_attempts is NOTSET: + # ignoring mypy because it doesn't like ArgNotSet as an operand, but neither variables + # are of type ArgNotSet at this point. + waiter_max_attempts = waiter_countdown // waiter_delay # type: ignore[operator] + warnings.warn( + "The parameter waiter_countdown has been deprecated to standardize " + "naming conventions. Please use waiter_max_attempts instead. In the " + "future this will default to None and defer to the waiter's default value." + ) self.wait_for_delete_completion = wait_for_completion # super stops the app super().__init__( @@ -1304,8 +1388,8 @@ def __init__( # when deleting an app, we always need to wait for it to stop before we can call delete() wait_for_completion=True, aws_conn_id=aws_conn_id, - waiter_countdown=waiter_countdown, - waiter_check_interval_seconds=waiter_check_interval_seconds, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, force_stop=force_stop, **kwargs, ) @@ -1321,17 +1405,16 @@ def execute(self, context: Context) -> None: raise AirflowException(f"Application deletion failed: {response}") if self.wait_for_delete_completion: - # This should be replaced with a boto waiter when available. 
- waiter( - get_state_callable=self.hook.conn.get_application, - get_state_args={"applicationId": self.application_id}, - parse_response=["application", "state"], - desired_state={"TERMINATED"}, - failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES, - object_type="application", - action="deleted", - countdown=self.waiter_countdown, - check_interval_seconds=self.waiter_check_interval_seconds, + waiter = self.hook.get_waiter("serverless_app_terminated") + + wait( + waiter=waiter, + max_attempts=self.waiter_max_attempts, + waiter_delay=self.waiter_delay, + args={"applicationId": self.application_id}, + failure_message="Error terminating application", + status_message="Serverless Application status is", + status_args=["application.state", "application.stateDetails"], ) self.log.info("EMR serverless application deleted") diff --git a/airflow/providers/amazon/aws/utils/waiter_with_logging.py b/airflow/providers/amazon/aws/utils/waiter_with_logging.py new file mode 100644 index 0000000000000..8c9e33077f6ed --- /dev/null +++ b/airflow/providers/amazon/aws/utils/waiter_with_logging.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import logging +import time + +import jmespath +from botocore.exceptions import WaiterError +from botocore.waiter import Waiter + +from airflow.exceptions import AirflowException + + +def wait( + waiter: Waiter, + waiter_delay: int, + max_attempts: int, + args: dict, + failure_message: str, + status_message: str, + status_args: list, +) -> None: + """ + Use a boto waiter to poll an AWS service for the specified state. Although this function + uses boto waiters to poll the state of the service, it logs the response of the service + after every attempt, which is not currently supported by boto waiters. + + :param waiter: The boto waiter to use. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param max_attempts: The maximum number of attempts to be made. + :param args: The arguments to pass to the waiter. + :param failure_message: The message to log if a failure state is reached. + :param status_message: The message logged when printing the status of the service. + :param status_args: A list containing the arguments to retrieve status information from + the waiter response. + e.g. 
+ response = {"Cluster": {"state": "CREATING"}} + status_args = ["Cluster.state"] + + response = { + "Clusters": [{"state": "CREATING", "details": "User initiated."},] + } + status_args = ["Clusters[0].state", "Clusters[0].details"] + """ + log = logging.getLogger(__name__) + attempt = 0 + while True: + attempt += 1 + try: + waiter.wait(**args, WaiterConfig={"MaxAttempts": 1}) + break + except WaiterError as error: + if "terminal failure" in str(error): + raise AirflowException(f"{failure_message}: {error}") + status_string = _format_status_string(status_args, error.last_response) + log.info("%s: %s", status_message, status_string) + time.sleep(waiter_delay) + + if attempt >= max_attempts: + raise AirflowException("Waiter error: max attempts reached") + + +def _format_status_string(args, response): + """ + Loops through the supplied args list and generates a string + which contains values from the waiter response. + """ + values = [] + for arg in args: + value = jmespath.search(arg, response) + if value is not None and value != "": + values.append(str(value)) + + return " - ".join(values) diff --git a/airflow/providers/amazon/aws/waiters/emr-serverless.json b/airflow/providers/amazon/aws/waiters/emr-serverless.json index a77d07f243687..4066109382a6a 100644 --- a/airflow/providers/amazon/aws/waiters/emr-serverless.json +++ b/airflow/providers/amazon/aws/waiters/emr-serverless.json @@ -13,6 +13,145 @@ "state": "success" } ] + }, + "serverless_app_created": { + "operation": "GetApplication", + "delay": 60, + "maxAttempts": 1500, + "acceptors": [ + { + "matcher": "path", + "argument": "application.state", + "expected": "CREATED", + "state": "success" + }, + { + "matcher": "path", + "argument": "application.state", + "expected": "TERMINATED", + "state": "failure" + } + ] + }, + "serverless_app_started": { + "operation": "GetApplication", + "delay": 60, + "maxAttempts": 1500, + "acceptors": [ + { + "matcher": "path", + "argument": "application.state", + "expected": 
"STARTED", + "state": "success" + }, + { + "matcher": "path", + "argument": "application.state", + "expected": "TERMINATED", + "state": "failure" + }, + { + "matcher": "path", + "argument": "application.state", + "expected": "STOPPED", + "state": "failure" + } + ] + }, + "serverless_app_stopped": { + "operation": "GetApplication", + "delay": 60, + "maxAttempts": 1500, + "acceptors": [ + { + "matcher": "path", + "argument": "application.state", + "expected": "STOPPED", + "state": "success" + }, + { + "matcher": "path", + "argument": "application.state", + "expected": "TERMINATED", + "state": "failure" + } + ] + }, + "serverless_app_terminated": { + "operation": "GetApplication", + "delay": 60, + "maxAttempts": 1500, + "acceptors": [ + { + "matcher": "path", + "argument": "application.state", + "expected": "TERMINATED", + "state": "success" + } + ] + }, + "serverless_job_completed": { + "operation": "GetJobRun", + "delay": 60, + "maxAttempts": 1500, + "acceptors": [ + { + "matcher": "path", + "argument": "jobRun.state", + "expected": "SUCCESS", + "state": "success" + }, + { + "matcher": "path", + "argument": "jobRun.state", + "expected": "FAILED", + "state": "failure" + }, + { + "matcher": "path", + "argument": "jobRun.state", + "expected": "CANCELLED", + "state": "failure" + } + ] + }, + "serverless_job_running": { + "operation": "GetJobRun", + "delay": 60, + "maxAttempts": 1500, + "acceptors": [ + { + "matcher": "path", + "argument": "jobRun.state", + "expected": "RUNNING", + "state": "success" + }, + { + "matcher": "path", + "argument": "jobRun.state", + "expected": "FAILED", + "state": "failure" + }, + { + "matcher": "path", + "argument": "jobRun.state", + "expected": "CANCELLED", + "state": "failure" + } + ] + }, + "serverless_app_deleted": { + "operation": "GetApplication", + "delay": 60, + "maxAttempts": 1500, + "acceptors": [ + { + "matcher": "path", + "argument": "application.state", + "expected": "TERMINATED", + "state": "success" + } + ] } } } diff --git 
a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py index 6889a374ce7fd..8cb4eb1707b53 100644 --- a/tests/providers/amazon/aws/operators/test_emr_serverless.py +++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py @@ -21,14 +21,17 @@ from uuid import UUID import pytest +from botocore.exceptions import WaiterError from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook from airflow.providers.amazon.aws.operators.emr import ( EmrServerlessCreateApplicationOperator, EmrServerlessDeleteApplicationOperator, EmrServerlessStartJobOperator, EmrServerlessStopApplicationOperator, ) +from airflow.utils.types import NOTSET task_id = "test_emr_serverless_task_id" application_id = "test_application_id" @@ -46,8 +49,10 @@ class TestEmrServerlessCreateApplicationOperator: - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_execute_successfully_with_wait_for_completion(self, mock_conn): + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_execute_successfully_with_wait_for_completion(self, mock_conn, mock_waiter): + mock_waiter().wait.return_value = True mock_conn.create_application.return_value = { "applicationId": application_id, "ResponseMetadata": {"HTTPStatusCode": 200}, @@ -63,6 +68,8 @@ def test_execute_successfully_with_wait_for_completion(self, mock_conn): job_type=job_type, client_request_token=client_request_token, config=config, + waiter_max_attempts=3, + waiter_delay=0, ) id = operator.execute(None) @@ -73,15 +80,22 @@ def test_execute_successfully_with_wait_for_completion(self, mock_conn): type=job_type, **config, ) + mock_waiter().wait.assert_called_with( + applicationId=application_id, + WaiterConfig={ + "MaxAttempts": 1, + }, + ) + assert mock_waiter().wait.call_count == 2 + 
mock_conn.start_application.assert_called_once_with(applicationId=application_id) assert id == application_id mock_conn.get_application.call_count == 2 - # @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter") - @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") def test_execute_successfully_no_wait_for_completion(self, mock_conn, mock_waiter): - mock_waiter.return_value = True + mock_waiter().wait.return_value = True mock_conn.create_application.return_value = { "applicationId": application_id, "ResponseMetadata": {"HTTPStatusCode": 200}, @@ -106,13 +120,11 @@ def test_execute_successfully_no_wait_for_completion(self, mock_conn, mock_waite ) mock_conn.start_application.assert_called_once_with(applicationId=application_id) - mock_waiter.assert_called_once() + mock_waiter().wait.assert_called_once() assert id == application_id - @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_failed_create_application_request(self, mock_conn, mock_waiter): - mock_waiter.return_value = True + @mock.patch.object(EmrServerlessHook, "conn") + def test_failed_create_application_request(self, mock_conn): mock_conn.create_application.return_value = { "applicationId": application_id, "ResponseMetadata": {"HTTPStatusCode": 404}, @@ -138,13 +150,19 @@ def test_failed_create_application_request(self, mock_conn, mock_waiter): **config, ) - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_failed_create_application(self, mock_conn): + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_failed_create_application(self, mock_conn, mock_get_waiter): + error = WaiterError( + 
name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={"application": {"state": "FAILED"}}, + ) + mock_get_waiter().wait.side_effect = error mock_conn.create_application.return_value = { "applicationId": application_id, "ResponseMetadata": {"HTTPStatusCode": 200}, } - mock_conn.get_application.return_value = {"application": {"state": "TERMINATED"}} operator = EmrServerlessCreateApplicationOperator( task_id=task_id, @@ -157,7 +175,7 @@ def test_failed_create_application(self, mock_conn): with pytest.raises(AirflowException) as ex_message: operator.execute(None) - assert "Application reached failure state" in str(ex_message.value) + assert "Serverless Application creation failed:" in str(ex_message.value) mock_conn.create_application.assert_called_once_with( clientToken=client_request_token, @@ -165,18 +183,51 @@ def test_failed_create_application(self, mock_conn): type=job_type, **config, ) - mock_conn.get_application.assert_called_once_with(applicationId=application_id) + mock_conn.create_application.return_value = { + "applicationId": application_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + error = WaiterError( + name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={"application": {"state": "TERMINATED"}}, + ) + mock_get_waiter().wait.side_effect = error + + operator = EmrServerlessCreateApplicationOperator( + task_id=task_id, + release_label=release_label, + job_type=job_type, + client_request_token=client_request_token, + config=config, + ) - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_failed_start_application(self, mock_conn): + with pytest.raises(AirflowException) as ex_message: + operator.execute(None) + + assert "Serverless Application creation failed:" in str(ex_message.value) + + mock_conn.create_application.assert_called_with( + clientToken=client_request_token, + releaseLabel=release_label, + type=job_type, + **config, + ) + 
mock_conn.create_application.call_count == 2 + + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_failed_start_application(self, mock_conn, mock_get_waiter): + error = WaiterError( + name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={"application": {"state": "TERMINATED"}}, + ) + mock_get_waiter().wait.side_effect = [True, error] mock_conn.create_application.return_value = { "applicationId": application_id, "ResponseMetadata": {"HTTPStatusCode": 200}, } - mock_conn.get_application.side_effect = [ - {"application": {"state": "CREATED"}}, - {"application": {"state": "TERMINATED"}}, - ] operator = EmrServerlessCreateApplicationOperator( task_id=task_id, @@ -189,7 +240,7 @@ def test_failed_start_application(self, mock_conn): with pytest.raises(AirflowException) as ex_message: operator.execute(None) - assert "Application reached failure state" in str(ex_message.value) + assert "Serverless Application failed to start:" in str(ex_message.value) mock_conn.create_application.assert_called_once_with( clientToken=client_request_token, @@ -197,12 +248,11 @@ def test_failed_start_application(self, mock_conn): type=job_type, **config, ) - mock_conn.get_application.call_count == 2 - @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") def test_no_client_request_token(self, mock_conn, mock_waiter): - mock_waiter.return_value = True + mock_waiter().wait.return_value = True mock_conn.create_application.return_value = { "applicationId": application_id, "ResponseMetadata": {"HTTPStatusCode": 200}, @@ -221,10 +271,16 @@ def test_no_client_request_token(self, mock_conn, mock_waiter): assert str(UUID(generated_client_token, version=4)) == generated_client_token - 
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_application_in_failure_state(self, mock_conn): + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_application_in_failure_state(self, mock_conn, mock_get_waiter): fail_state = "STOPPED" - mock_conn.get_application.return_value = {"application": {"state": fail_state}} + error = WaiterError( + name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={"application": {"state": fail_state}}, + ) + mock_get_waiter().wait.side_effect = [error] mock_conn.create_application.return_value = { "applicationId": application_id, "ResponseMetadata": {"HTTPStatusCode": 200}, @@ -241,7 +297,7 @@ def test_application_in_failure_state(self, mock_conn): with pytest.raises(AirflowException) as ex_message: operator.execute(None) - assert str(ex_message.value) == f"Application reached failure state {fail_state}." + assert str(ex_message.value) == f"Serverless Application creation failed: {error}" mock_conn.create_application.assert_called_once_with( clientToken=client_request_token, @@ -250,10 +306,39 @@ def test_application_in_failure_state(self, mock_conn): **config, ) + @pytest.mark.parametrize( + "waiter_delay, waiter_max_attempts, waiter_countdown, waiter_check_interval_seconds, expected", + [ + (NOTSET, NOTSET, NOTSET, NOTSET, [60, 25]), + (30, 10, NOTSET, NOTSET, [30, 10]), + (NOTSET, NOTSET, 30 * 15, 15, [15, 30]), + (10, 20, 30, 40, [10, 20]), + ], + ) + def test_create_application_waiter_params( + self, waiter_delay, waiter_max_attempts, waiter_countdown, waiter_check_interval_seconds, expected + ): + operator = EmrServerlessCreateApplicationOperator( + task_id=task_id, + release_label=release_label, + job_type=job_type, + client_request_token=client_request_token, + config=config, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + waiter_countdown=waiter_countdown, + 
waiter_check_interval_seconds=waiter_check_interval_seconds, + ) + assert operator.wait_for_completion is True + assert operator.waiter_delay == expected[0] + assert operator.waiter_max_attempts == expected[1] + class TestEmrServerlessStartJobOperator: - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_job_run_app_started(self, mock_conn): + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_job_run_app_started(self, mock_conn, mock_get_waiter): + mock_get_waiter().wait.return_value = True mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} mock_conn.start_job_run.return_value = { "jobRunId": job_run_id, @@ -283,18 +368,22 @@ def test_job_run_app_started(self, mock_conn): configurationOverrides=configuration_overrides, name=default_name, ) - mock_conn.get_job_run.assert_called_once_with(applicationId=application_id, jobRunId=job_run_id) - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_job_run_job_failed(self, mock_conn): + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_job_run_job_failed(self, mock_conn, mock_get_waiter): + error = WaiterError( + name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={"jobRun": {"state": "FAILED"}}, + ) + mock_get_waiter().wait.side_effect = [error] mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} mock_conn.start_job_run.return_value = { "jobRunId": job_run_id, "ResponseMetadata": {"HTTPStatusCode": 200}, } - mock_conn.get_job_run.return_value = {"jobRun": {"state": "FAILED"}} - operator = EmrServerlessStartJobOperator( task_id=task_id, client_request_token=client_request_token, @@ -307,9 +396,8 @@ def test_job_run_job_failed(self, mock_conn): with pytest.raises(AirflowException) as ex_message: id = operator.execute(None) assert id == 
job_run_id - assert "Job reached failure state FAILED." in str(ex_message.value) + assert "Serverless Job failed:" in str(ex_message.value) mock_conn.get_application.assert_called_once_with(applicationId=application_id) - mock_conn.get_job_run.assert_called_once_with(applicationId=application_id, jobRunId=job_run_id) mock_conn.start_job_run.assert_called_once_with( clientToken=client_request_token, applicationId=application_id, @@ -319,10 +407,10 @@ def test_job_run_job_failed(self, mock_conn): name=default_name, ) - @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_job_run_app_not_started(self, mock_conn, mock_waiter): - mock_waiter.return_value = True + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_job_run_app_not_started(self, mock_conn, mock_get_waiter): + mock_get_waiter().wait.return_value = True mock_conn.get_application.return_value = {"application": {"state": "CREATING"}} mock_conn.start_job_run.return_value = { "jobRunId": job_run_id, @@ -343,7 +431,7 @@ def test_job_run_app_not_started(self, mock_conn, mock_waiter): assert operator.wait_for_completion is True mock_conn.get_application.assert_called_once_with(applicationId=application_id) - assert mock_waiter.call_count == 2 + assert mock_get_waiter().wait.call_count == 2 assert id == job_run_id mock_conn.start_job_run.assert_called_once_with( clientToken=client_request_token, @@ -354,12 +442,21 @@ def test_job_run_app_not_started(self, mock_conn, mock_waiter): name=default_name, ) - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_job_run_app_not_started_app_failed(self, mock_conn): - mock_conn.get_application.side_effect = [ - {"application": {"state": "CREATING"}}, - {"application": {"state": "TERMINATED"}}, - ] + @mock.patch("time.sleep", return_value=True) + @mock.patch.object(EmrServerlessHook, 
"get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_job_run_app_not_started_app_failed(self, mock_conn, mock_get_waiter, mock_time): + error1 = WaiterError( + name="test_name", + reason="test-reason", + last_response={"application": {"state": "CREATING", "stateDetails": "test-details"}}, + ) + error2 = WaiterError( + name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={"application": {"state": "TERMINATED", "stateDetails": "test-details"}}, + ) + mock_get_waiter().wait.side_effect = [error1, error2] mock_conn.start_job_run.return_value = { "jobRunId": job_run_id, "ResponseMetadata": {"HTTPStatusCode": 200}, @@ -375,15 +472,14 @@ def test_job_run_app_not_started_app_failed(self, mock_conn): ) with pytest.raises(AirflowException) as ex_message: operator.execute(None) - assert "Application reached failure state" in str(ex_message.value) + assert "Serverless Application failed to start:" in str(ex_message.value) assert operator.wait_for_completion is True - mock_conn.get_application.call_count == 2 - mock_conn.assert_not_called() + assert mock_get_waiter().wait.call_count == 2 - @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_job_run_app_not_started_no_wait_for_completion(self, mock_conn, mock_waiter): - mock_waiter.return_value = True + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_job_run_app_not_started_no_wait_for_completion(self, mock_conn, mock_get_waiter): + mock_get_waiter().wait.return_value = True mock_conn.get_application.return_value = {"application": {"state": "CREATING"}} mock_conn.start_job_run.return_value = { "jobRunId": job_run_id, @@ -403,7 +499,7 @@ def test_job_run_app_not_started_no_wait_for_completion(self, mock_conn, mock_wa id = operator.execute(None) 
mock_conn.get_application.assert_called_once_with(applicationId=application_id) - mock_waiter.assert_called_once() + mock_get_waiter().wait.assert_called_once() assert id == job_run_id mock_conn.start_job_run.assert_called_once_with( clientToken=client_request_token, @@ -414,10 +510,10 @@ def test_job_run_app_not_started_no_wait_for_completion(self, mock_conn, mock_wa name=default_name, ) - @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_job_run_app_started_no_wait_for_completion(self, mock_conn, mock_waiter): - mock_waiter.return_value = True + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_job_run_app_started_no_wait_for_completion(self, mock_conn, mock_get_waiter): + mock_get_waiter().wait.return_value = True mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} mock_conn.start_job_run.return_value = { "jobRunId": job_run_id, @@ -444,12 +540,12 @@ def test_job_run_app_started_no_wait_for_completion(self, mock_conn, mock_waiter configurationOverrides=configuration_overrides, name=default_name, ) - assert not mock_waiter.called + assert not mock_get_waiter().wait.called - @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_failed_start_job_run(self, mock_conn, mock_waiter): - mock_waiter.return_value = True + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_failed_start_job_run(self, mock_conn, mock_get_waiter): + mock_get_waiter().wait.return_value = True mock_conn.get_application.return_value = {"application": {"state": "CREATING"}} mock_conn.start_job_run.return_value = { "jobRunId": job_run_id, @@ -470,7 +566,7 @@ def test_failed_start_job_run(self, mock_conn, mock_waiter): assert "EMR serverless job 
failed to start:" in str(ex_message.value) mock_conn.get_application.assert_called_once_with(applicationId=application_id) - mock_waiter.assert_called_once() + mock_get_waiter().wait.assert_called_once() mock_conn.start_job_run.assert_called_once_with( clientToken=client_request_token, applicationId=application_id, @@ -480,15 +576,20 @@ def test_failed_start_job_run(self, mock_conn, mock_waiter): name=default_name, ) - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_start_job_run_fail_on_wait_for_completion(self, mock_conn): + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_start_job_run_fail_on_wait_for_completion(self, mock_conn, mock_get_waiter): + error = WaiterError( + name="mock_waiter_error", + reason="Waiter encountered a terminal failure state:", + last_response={"jobRun": {"state": "FAILED", "stateDetails": "Test Details"}}, + ) + mock_get_waiter().wait.side_effect = [error] mock_conn.get_application.return_value = {"application": {"state": "CREATED"}} mock_conn.start_job_run.return_value = { "jobRunId": job_run_id, "ResponseMetadata": {"HTTPStatusCode": 200}, } - mock_conn.get_job_run.return_value = {"jobRun": {"state": "FAILED"}} - operator = EmrServerlessStartJobOperator( task_id=task_id, client_request_token=client_request_token, @@ -501,7 +602,7 @@ def test_start_job_run_fail_on_wait_for_completion(self, mock_conn): with pytest.raises(AirflowException) as ex_message: operator.execute(None) - assert "Job reached failure state" in str(ex_message.value) + assert "Serverless Job failed:" in str(ex_message.value) mock_conn.get_application.call_count == 2 mock_conn.start_job_run.assert_called_once_with( clientToken=client_request_token, @@ -511,15 +612,17 @@ def test_start_job_run_fail_on_wait_for_completion(self, mock_conn): configurationOverrides=configuration_overrides, name=default_name, ) + mock_get_waiter().wait.assert_called_once() - 
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_start_job_default_name(self, mock_conn): + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_start_job_default_name(self, mock_conn, mock_get_waiter): mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} mock_conn.start_job_run.return_value = { "jobRunId": job_run_id, "ResponseMetadata": {"HTTPStatusCode": 200}, } - mock_conn.get_job_run.return_value = {"jobRun": {"state": "SUCCESS"}} + mock_get_waiter().wait.return_value = True operator = EmrServerlessStartJobOperator( task_id=task_id, @@ -543,15 +646,16 @@ def test_start_job_default_name(self, mock_conn): name=f"emr_serverless_job_airflow_{str(UUID(generated_name_uuid, version=4))}", ) - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_start_job_custom_name(self, mock_conn): + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_start_job_custom_name(self, mock_conn, mock_get_waiter): + mock_get_waiter().wait.return_value = True mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} custom_name = "test_name" mock_conn.start_job_run.return_value = { "jobRunId": job_run_id, "ResponseMetadata": {"HTTPStatusCode": 200}, } - mock_conn.get_job_run.return_value = {"jobRun": {"state": "SUCCESS"}} operator = EmrServerlessStartJobOperator( task_id=task_id, @@ -573,7 +677,7 @@ def test_start_job_custom_name(self, mock_conn): name=custom_name, ) - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") + @mock.patch.object(EmrServerlessHook, "conn") def test_cancel_job_run(self, mock_conn): mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} mock_conn.start_job_run.return_value = { @@ -599,12 +703,39 @@ def test_cancel_job_run(self, mock_conn): jobRunId=id, ) + @pytest.mark.parametrize( 
+ "waiter_delay, waiter_max_attempts, waiter_countdown, waiter_check_interval_seconds, expected", + [ + (NOTSET, NOTSET, NOTSET, NOTSET, [60, 25]), + (30, 10, NOTSET, NOTSET, [30, 10]), + (NOTSET, NOTSET, 30 * 15, 15, [15, 30]), + (10, 20, 30, 40, [10, 20]), + ], + ) + def test_start_job_waiter_params( + self, waiter_delay, waiter_max_attempts, waiter_countdown, waiter_check_interval_seconds, expected + ): + operator = EmrServerlessStartJobOperator( + task_id=task_id, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=configuration_overrides, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + waiter_countdown=waiter_countdown, + waiter_check_interval_seconds=waiter_check_interval_seconds, + ) + assert operator.wait_for_completion is True + assert operator.waiter_delay == expected[0] + assert operator.waiter_max_attempts == expected[1] + class TestEmrServerlessDeleteOperator: - @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_delete_application_with_wait_for_completion_successfully(self, mock_conn, mock_waiter): - mock_waiter.return_value = True + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_delete_application_with_wait_for_completion_successfully(self, mock_conn, mock_get_waiter): + mock_get_waiter().wait.return_value = True mock_conn.stop_application.return_value = {} mock_conn.delete_application.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}} @@ -615,14 +746,14 @@ def test_delete_application_with_wait_for_completion_successfully(self, mock_con operator.execute(None) assert operator.wait_for_completion is True - assert mock_waiter.call_count == 2 + assert mock_get_waiter().wait.call_count == 2 mock_conn.stop_application.assert_called_once() 
mock_conn.delete_application.assert_called_once_with(applicationId=application_id_delete_operator) - @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_delete_application_without_wait_for_completion_successfully(self, mock_conn, mock_waiter): - mock_waiter.return_value = True + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_delete_application_without_wait_for_completion_successfully(self, mock_conn, mock_get_waiter): + mock_get_waiter().wait.return_value = True mock_conn.stop_application.return_value = {} mock_conn.delete_application.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}} @@ -634,14 +765,14 @@ def test_delete_application_without_wait_for_completion_successfully(self, mock_ operator.execute(None) - mock_waiter.assert_called_once() + mock_get_waiter().wait.assert_called_once() mock_conn.stop_application.assert_called_once() mock_conn.delete_application.assert_called_once_with(applicationId=application_id_delete_operator) - @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_delete_application_failed_deletion(self, mock_conn, mock_waiter): - mock_waiter.return_value = True + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_delete_application_failed_deletion(self, mock_conn, mock_get_waiter): + mock_get_waiter().wait.return_value = True mock_conn.stop_application.return_value = {} mock_conn.delete_application.return_value = {"ResponseMetadata": {"HTTPStatusCode": 400}} @@ -653,37 +784,61 @@ def test_delete_application_failed_deletion(self, mock_conn, mock_waiter): assert "Application deletion failed:" in str(ex_message.value) - mock_waiter.assert_called_once() + mock_get_waiter().wait.assert_called_once() 
mock_conn.stop_application.assert_called_once() mock_conn.delete_application.assert_called_once_with(applicationId=application_id_delete_operator) + @pytest.mark.parametrize( + "waiter_delay, waiter_max_attempts, waiter_countdown, waiter_check_interval_seconds, expected", + [ + (NOTSET, NOTSET, NOTSET, NOTSET, [60, 25]), + (30, 10, NOTSET, NOTSET, [30, 10]), + (NOTSET, NOTSET, 30 * 15, 15, [15, 30]), + (10, 20, 30, 40, [10, 20]), + ], + ) + def test_delete_application_waiter_params( + self, waiter_delay, waiter_max_attempts, waiter_countdown, waiter_check_interval_seconds, expected + ): + operator = EmrServerlessDeleteApplicationOperator( + task_id=task_id, + application_id=application_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + waiter_countdown=waiter_countdown, + waiter_check_interval_seconds=waiter_check_interval_seconds, + ) + assert operator.wait_for_completion is True + assert operator.waiter_delay == expected[0] + assert operator.waiter_max_attempts == expected[1] + class TestEmrServerlessStopOperator: - @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_stop(self, mock_conn: MagicMock, mock_waiter: MagicMock): + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_stop(self, mock_conn: MagicMock, mock_get_waiter: MagicMock): + mock_get_waiter().wait.return_value = True operator = EmrServerlessStopApplicationOperator(task_id=task_id, application_id="test") operator.execute(None) - mock_waiter.assert_called_once() + mock_get_waiter().wait.assert_called_once() mock_conn.stop_application.assert_called_once() - @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn") - def test_stop_no_wait(self, mock_conn: MagicMock, mock_waiter: MagicMock): + @mock.patch.object(EmrServerlessHook, 
"get_waiter") + @mock.patch.object(EmrServerlessHook, "conn") + def test_stop_no_wait(self, mock_conn: MagicMock, mock_get_waiter: MagicMock): operator = EmrServerlessStopApplicationOperator( task_id=task_id, application_id="test", wait_for_completion=False ) operator.execute(None) - mock_waiter.assert_not_called() + mock_get_waiter().wait.assert_not_called() mock_conn.stop_application.assert_called_once() - @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") @mock.patch.object(EmrServerlessStopApplicationOperator, "hook", new_callable=PropertyMock) - def test_force_stop(self, mock_hook: MagicMock, mock_waiter: MagicMock): + def test_force_stop(self, mock_hook: MagicMock): operator = EmrServerlessStopApplicationOperator( task_id=task_id, application_id="test", force_stop=True ) @@ -692,4 +847,4 @@ def test_force_stop(self, mock_hook: MagicMock, mock_waiter: MagicMock): mock_hook().cancel_running_jobs.assert_called_once() mock_hook().conn.stop_application.assert_called_once() - mock_waiter.assert_called_once() + mock_hook().get_waiter().wait.assert_called_once() diff --git a/tests/providers/amazon/aws/utils/test_waiter_with_logging.py b/tests/providers/amazon/aws/utils/test_waiter_with_logging.py new file mode 100644 index 0000000000000..2ca74936d7d71 --- /dev/null +++ b/tests/providers/amazon/aws/utils/test_waiter_with_logging.py @@ -0,0 +1,304 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from typing import Any +from unittest import mock + +import pytest +from botocore.exceptions import WaiterError + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.utils.waiter_with_logging import wait + + +def generate_response(state: str) -> dict[str, Any]: + return { + "Status": { + "State": state, + }, + } + + +class TestWaiter: + @mock.patch("time.sleep") + def test_wait(self, mock_sleep, caplog): + mock_sleep.return_value = True + mock_waiter = mock.MagicMock() + error = WaiterError( + name="test_waiter", + reason="test_reason", + last_response=generate_response("Pending"), + ) + mock_waiter.wait.side_effect = [error, error, True] + wait( + waiter=mock_waiter, + waiter_delay=123, + max_attempts=456, + args={"test_arg": "test_value"}, + failure_message="test failure message", + status_message="test status message", + status_args=["Status.State"], + ) + + mock_waiter.wait.assert_called_with( + **{"test_arg": "test_value"}, + WaiterConfig={ + "MaxAttempts": 1, + }, + ) + mock_waiter.wait.call_count == 3 + mock_sleep.assert_called_with(123) + assert ( + caplog.record_tuples + == [ + ( + "airflow.providers.amazon.aws.utils.waiter_with_logging", + logging.INFO, + "test status message: Pending", + ) + ] + * 2 + ) + + @mock.patch("time.sleep") + def test_wait_max_attempts_exceeded(self, mock_sleep, caplog): + mock_sleep.return_value = True + mock_waiter = mock.MagicMock() + error = WaiterError( + name="test_waiter", + reason="test_reason", + 
last_response=generate_response("Pending"), + ) + mock_waiter.wait.side_effect = [error, error, error] + with pytest.raises(AirflowException) as exc: + wait( + waiter=mock_waiter, + waiter_delay=123, + max_attempts=2, + args={"test_arg": "test_value"}, + failure_message="test failure message", + status_message="test status message", + status_args=["Status.State"], + ) + assert "Waiter error: max attempts reached" in str(exc) + mock_waiter.wait.assert_called_with( + **{"test_arg": "test_value"}, + WaiterConfig={ + "MaxAttempts": 1, + }, + ) + + mock_waiter.wait.call_count == 11 + mock_sleep.assert_called_with(123) + assert ( + caplog.record_tuples + == [ + ( + "airflow.providers.amazon.aws.utils.waiter_with_logging", + logging.INFO, + "test status message: Pending", + ) + ] + * 2 + ) + + @mock.patch("time.sleep") + def test_wait_with_failure(self, mock_sleep, caplog): + mock_sleep.return_value = True + mock_waiter = mock.MagicMock() + error = WaiterError( + name="test_waiter", + reason="test_reason", + last_response=generate_response("Pending"), + ) + failure_error = WaiterError( + name="test_waiter", + reason="terminal failure in waiter", + last_response=generate_response("Failure"), + ) + mock_waiter.wait.side_effect = [error, error, error, failure_error] + with pytest.raises(AirflowException) as exc: + wait( + waiter=mock_waiter, + waiter_delay=123, + max_attempts=10, + args={"test_arg": "test_value"}, + failure_message="test failure message", + status_message="test status message", + status_args=["Status.State"], + ) + assert "test failure message" in str(exc) + mock_waiter.wait.assert_called_with( + **{"test_arg": "test_value"}, + WaiterConfig={ + "MaxAttempts": 1, + }, + ) + assert mock_waiter.wait.call_count == 4 + assert ( + caplog.record_tuples + == [ + ( + "airflow.providers.amazon.aws.utils.waiter_with_logging", + logging.INFO, + "test status message: Pending", + ) + ] + * 3 + ) + + @mock.patch("time.sleep") + def test_wait_with_list_response(self, 
mock_sleep, caplog): + mock_sleep.return_value = True + mock_waiter = mock.MagicMock() + error = WaiterError( + name="test_waiter", + reason="test_reason", + last_response={ + "Clusters": [ + { + "Status": "Pending", + }, + { + "Status": "Pending", + }, + ] + }, + ) + mock_waiter.wait.side_effect = [error, error, True] + wait( + waiter=mock_waiter, + waiter_delay=123, + max_attempts=456, + args={"test_arg": "test_value"}, + failure_message="test failure message", + status_message="test status message", + status_args=["Clusters[0].Status"], + ) + + mock_waiter.wait.assert_called_with( + **{"test_arg": "test_value"}, + WaiterConfig={ + "MaxAttempts": 1, + }, + ) + mock_waiter.wait.call_count == 3 + mock_sleep.assert_called_with(123) + assert ( + caplog.record_tuples + == [ + ( + "airflow.providers.amazon.aws.utils.waiter_with_logging", + logging.INFO, + "test status message: Pending", + ) + ] + * 2 + ) + + @mock.patch("time.sleep") + def test_wait_with_incorrect_args(self, mock_sleep, caplog): + mock_sleep.return_value = True + mock_waiter = mock.MagicMock() + error = WaiterError( + name="test_waiter", + reason="test_reason", + last_response={ + "Clusters": [ + { + "Status": "Pending", + }, + { + "Status": "Pending", + }, + ] + }, + ) + mock_waiter.wait.side_effect = [error, error, True] + wait( + waiter=mock_waiter, + waiter_delay=123, + max_attempts=456, + args={"test_arg": "test_value"}, + failure_message="test failure message", + status_message="test status message", + status_args=["Clusters[0].State"], # this does not exist in the response + ) + + mock_waiter.wait.assert_called_with( + **{"test_arg": "test_value"}, + WaiterConfig={ + "MaxAttempts": 1, + }, + ) + mock_waiter.wait.call_count == 3 + mock_sleep.assert_called_with(123) + assert ( + caplog.record_tuples + == [ + ( + "airflow.providers.amazon.aws.utils.waiter_with_logging", + logging.INFO, + "test status message: ", + ) + ] + * 2 + ) + + @mock.patch("time.sleep") + def 
test_wait_with_multiple_args(self, mock_sleep, caplog): + mock_sleep.return_value = True + mock_waiter = mock.MagicMock() + error = WaiterError( + name="test_waiter", + reason="test_reason", + last_response={ + "Clusters": [ + { + "Status": "Pending", + "StatusDetails": "test_details", + "ClusterName": "test_name", + }, + ] + }, + ) + mock_waiter.wait.side_effect = [error, error, True] + wait( + waiter=mock_waiter, + waiter_delay=123, + max_attempts=456, + args={"test_arg": "test_value"}, + failure_message="test failure message", + status_message="test status message", + status_args=["Clusters[0].Status", "Clusters[0].StatusDetails", "Clusters[0].ClusterName"], + ) + mock_waiter.wait.call_count == 3 + mock_sleep.assert_called_with(123) + assert ( + caplog.record_tuples + == [ + ( + "airflow.providers.amazon.aws.utils.waiter_with_logging", + logging.INFO, + "test status message: Pending - test_details - test_name", + ) + ] + * 2 + )