From 83bd60fd97d4ca622adcbd7898d88880fee43054 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= <6774676+eumiro@users.noreply.github.com> Date: Wed, 9 Aug 2023 14:19:00 +0000 Subject: [PATCH] Refactor: Simplify code in providers/amazon (#33222) --- .../providers/amazon/aws/hooks/base_aws.py | 7 +- .../amazon/aws/hooks/batch_client.py | 91 +++++++++---------- .../providers/amazon/aws/hooks/datasync.py | 28 ++---- .../amazon/aws/hooks/redshift_data.py | 6 +- airflow/providers/amazon/aws/hooks/s3.py | 7 +- .../providers/amazon/aws/hooks/sagemaker.py | 31 +++---- .../amazon/aws/log/s3_task_handler.py | 7 +- .../amazon/aws/operators/redshift_cluster.py | 18 ++-- .../amazon/aws/secrets/secrets_manager.py | 5 +- .../amazon/aws/transfers/redshift_to_s3.py | 2 +- .../amazon/aws/transfers/s3_to_redshift.py | 2 +- .../amazon/aws/transfers/sql_to_s3.py | 2 +- .../providers/amazon/aws/triggers/batch.py | 8 +- airflow/providers/amazon/aws/triggers/ecs.py | 4 +- airflow/providers/amazon/aws/triggers/emr.py | 4 +- .../amazon/aws/triggers/sagemaker.py | 4 +- .../amazon/aws/utils/waiter_with_logging.py | 30 +++--- 17 files changed, 112 insertions(+), 144 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index 01748089c893..a908882e89ba 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -494,10 +494,9 @@ def _find_class_name(target_function_name: str) -> str: responsible with catching and handling those. """ stack = inspect.stack() - # Find the index of the most recent frame which called the provided function name. - target_frame_index = [frame.function for frame in stack].index(target_function_name) - # Pull that frame off the stack. - target_frame = stack[target_frame_index][0] + # Find the index of the most recent frame which called the provided function name + # and pull that frame off the stack. + target_frame = next(frame for frame in stack if frame.function == target_function_name)[0] # Get the local variables for that frame. frame_variables = target_frame.f_locals["self"] # Get the class object for that frame. 
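The base_aws.py hunk above replaces a two-step frame lookup (build a list of every frame's function name, call .index() on it, then index back into the stack) with a single next() over a generator expression, which stops at the first matching frame. A minimal standalone sketch of the same idiom, using made-up caller/callee names rather than the hook's real call sites:

import inspect


def _calling_frame_locals(target_function_name: str) -> dict:
    """Return the locals of the most recent frame running the named function."""
    stack = inspect.stack()
    # next() stops at the first frame whose function name matches, instead of
    # materialising a full list of names and calling .index() on it.
    target_frame = next(frame for frame in stack if frame.function == target_function_name)[0]
    return target_frame.f_locals


def outer():
    answer = 42  # a local we expect the helper to see
    return inner()


def inner():
    # Look up the locals of the frame that called us ("outer").
    return _calling_frame_locals("outer")


if __name__ == "__main__":
    print(outer()["answer"])  # -> 42

If no frame matches, next() raises StopIteration where the old code raised ValueError from .index(), so the failure stays loud either way.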
diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py index e585297d3a62..87637be410a1 100644 --- a/airflow/providers/amazon/aws/hooks/batch_client.py +++ b/airflow/providers/amazon/aws/hooks/batch_client.py @@ -26,6 +26,7 @@ """ from __future__ import annotations +import itertools as it from random import uniform from time import sleep from typing import Callable @@ -343,8 +344,17 @@ def poll_job_status(self, job_id: str, match_status: list[str]) -> bool: :raises: AirflowException """ - retries = 0 - while True: + for retries in range(1 + self.max_retries): + if retries: + pause = self.exponential_delay(retries) + self.log.info( + "AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds", + job_id, + retries, + self.max_retries, + pause, + ) + self.delay(pause) job = self.get_job_description(job_id) job_status = job.get("status") @@ -354,23 +364,10 @@ def poll_job_status(self, job_id: str, match_status: list[str]) -> bool: job_status, match_status, ) - if job_status in match_status: return True - - if retries >= self.max_retries: - raise AirflowException(f"AWS Batch job ({job_id}) status checks exceed max_retries") - - retries += 1 - pause = self.exponential_delay(retries) - self.log.info( - "AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds", - job_id, - retries, - self.max_retries, - pause, - ) - self.delay(pause) + else: + raise AirflowException(f"AWS Batch job ({job_id}) status checks exceed max_retries") def get_job_description(self, job_id: str) -> dict: """ @@ -382,12 +379,21 @@ def get_job_description(self, job_id: str) -> dict: :raises: AirflowException """ - retries = 0 - while True: + for retries in range(self.status_retries): + if retries: + pause = self.exponential_delay(retries) + self.log.info( + "AWS Batch job (%s) description retry (%d of %d) in the next %.2f seconds", + job_id, + retries, + self.status_retries, + pause, + ) + self.delay(pause) + try: response = self.get_conn().describe_jobs(jobs=[job_id]) return self.parse_job_description(job_id, response) - except botocore.exceptions.ClientError as err: # Allow it to retry in case of exceeded quota limit of requests to AWS API if err.response.get("Error", {}).get("Code") != "TooManyRequestsException": @@ -398,23 +404,11 @@ def get_job_description(self, job_id: str) -> dict: "check Amazon Provider AWS Connection documentation for more details.", str(err), ) - - retries += 1 - if retries >= self.status_retries: - raise AirflowException( - f"AWS Batch job ({job_id}) description error: exceeded status_retries " - f"({self.status_retries})" - ) - - pause = self.exponential_delay(retries) - self.log.info( - "AWS Batch job (%s) description retry (%d of %d) in the next %.2f seconds", - job_id, - retries, - self.status_retries, - pause, + else: + raise AirflowException( + f"AWS Batch job ({job_id}) description error: exceeded status_retries " + f"({self.status_retries})" ) - self.delay(pause) @staticmethod def parse_job_description(job_id: str, response: dict) -> dict: @@ -476,7 +470,7 @@ def get_job_all_awslogs_info(self, job_id: str) -> list[dict[str, str]]: ) # If the user selected another logDriver than "awslogs", then CloudWatch logging is disabled. - if any([c.get("logDriver", "awslogs") != "awslogs" for c in log_configs]): + if any(c.get("logDriver", "awslogs") != "awslogs" for c in log_configs): self.log.warning( f"AWS Batch job ({job_id}) uses non-aws log drivers. AWS CloudWatch logging disabled." 
) @@ -494,18 +488,17 @@ def get_job_all_awslogs_info(self, job_id: str) -> list[dict[str, str]]: # cross stream names with options (i.e. attempts X nodes) to generate all log infos result = [] - for stream in stream_names: - for option in log_options: - result.append( - { - "awslogs_stream_name": stream, - # If the user did not specify anything, the default settings are: - # awslogs-group = /aws/batch/job - # awslogs-region = `same as AWS Batch Job region` - "awslogs_group": option.get("awslogs-group", "/aws/batch/job"), - "awslogs_region": option.get("awslogs-region", self.conn_region_name), - } - ) + for stream, option in it.product(stream_names, log_options): + result.append( + { + "awslogs_stream_name": stream, + # If the user did not specify anything, the default settings are: + # awslogs-group = /aws/batch/job + # awslogs-region = `same as AWS Batch Job region` + "awslogs_group": option.get("awslogs-group", "/aws/batch/job"), + "awslogs_region": option.get("awslogs-region", self.conn_region_name), + } + ) return result @staticmethod diff --git a/airflow/providers/amazon/aws/hooks/datasync.py b/airflow/providers/amazon/aws/hooks/datasync.py index 1fd06eeefe3d..32fc7621531f 100644 --- a/airflow/providers/amazon/aws/hooks/datasync.py +++ b/airflow/providers/amazon/aws/hooks/datasync.py @@ -125,17 +125,11 @@ def get_location_arns( def _refresh_locations(self) -> None: """Refresh the local list of Locations.""" - self.locations = [] - next_token = None - while True: - if next_token: - locations = self.get_conn().list_locations(NextToken=next_token) - else: - locations = self.get_conn().list_locations() + locations = self.get_conn().list_locations() + self.locations = locations["Locations"] + while "NextToken" in locations: + locations = self.get_conn().list_locations(NextToken=locations["NextToken"]) self.locations.extend(locations["Locations"]) - if "NextToken" not in locations: - break - next_token = locations["NextToken"] def create_task( self, source_location_arn: str, destination_location_arn: str, **create_task_kwargs @@ -181,17 +175,11 @@ def delete_task(self, task_arn: str) -> None: def _refresh_tasks(self) -> None: """Refreshes the local list of Tasks.""" - self.tasks = [] - next_token = None - while True: - if next_token: - tasks = self.get_conn().list_tasks(NextToken=next_token) - else: - tasks = self.get_conn().list_tasks() + tasks = self.get_conn().list_tasks() + self.tasks = tasks["Tasks"] + while "NextToken" in tasks: + tasks = self.get_conn().list_tasks(NextToken=tasks["NextToken"]) self.tasks.extend(tasks["Tasks"]) - if "NextToken" not in tasks: - break - next_token = tasks["NextToken"] def get_task_arns_for_location_arns( self, diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py index a522d3e8c968..ba7a61c28f57 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_data.py +++ b/airflow/providers/amazon/aws/hooks/redshift_data.py @@ -195,9 +195,9 @@ def get_table_primary_key( # we only select a single column (that is a string), # so safe to assume that there is only a single col in the record pk_columns += [y["stringValue"] for x in response["Records"] for y in x] - if "NextToken" not in response.keys(): - break - else: + if "NextToken" in response: token = response["NextToken"] + else: + break return pk_columns or None diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index 8878f04e1b58..f0e98ed44887 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py 
+++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -1284,18 +1284,15 @@ def delete_bucket(self, bucket_name: str, force_delete: bool = False, max_retrie bucket and trying to delete the bucket. :return: None """ - tries_remaining = max_retries + 1 if force_delete: - while tries_remaining: + for retry in range(max_retries): bucket_keys = self.list_keys(bucket_name=bucket_name) if not bucket_keys: break - if tries_remaining <= max_retries: - # Avoid first loop + if retry: # Avoid first loop sleep(500) self.delete_objects(bucket=bucket_name, keys=bucket_keys) - tries_remaining -= 1 self.conn.delete_bucket(Bucket=bucket_name) diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py index 758ba7e9c5d5..b8247d9de97b 100644 --- a/airflow/providers/amazon/aws/hooks/sagemaker.py +++ b/airflow/providers/amazon/aws/hooks/sagemaker.py @@ -1143,15 +1143,12 @@ def stop_pipeline( if check_interval is None: check_interval = 10 - retries = 2 # i.e. 3 calls max, 1 initial + 2 retries - while True: + for retries in (2, 1, 0): try: self.conn.stop_pipeline_execution(PipelineExecutionArn=pipeline_exec_arn) - break except ClientError as ce: # this can happen if the pipeline was transitioning between steps at that moment - if ce.response["Error"]["Code"] == "ConflictException" and retries > 0: - retries = retries - 1 + if ce.response["Error"]["Code"] == "ConflictException" and retries: self.log.warning( "Got a conflict exception when trying to stop the pipeline, " "retrying %s more times. Error was: %s", @@ -1159,18 +1156,20 @@ def stop_pipeline( ce, ) time.sleep(0.3) # error is due to a race condition, so it should be very transient - continue - # we have to rely on the message to catch the right error here, because its type - # (ValidationException) is shared with other kinds of errors (e.g. badly formatted ARN) - if ( - not fail_if_not_running - and "Only pipelines with 'Executing' status can be stopped" - in ce.response["Error"]["Message"] - ): - self.log.warning("Cannot stop pipeline execution, as it was not running: %s", ce) else: - self.log.error(ce) - raise + # we have to rely on the message to catch the right error here, because its type + # (ValidationException) is shared with other kinds of errors (e.g. 
badly formatted ARN) + if ( + not fail_if_not_running + and "Only pipelines with 'Executing' status can be stopped" + in ce.response["Error"]["Message"] + ): + self.log.warning("Cannot stop pipeline execution, as it was not running: %s", ce) + break + else: + self.log.error(ce) + raise + else: break res = self.describe_pipeline_exec(pipeline_exec_arn) diff --git a/airflow/providers/amazon/aws/log/s3_task_handler.py b/airflow/providers/amazon/aws/log/s3_task_handler.py index a353d9b07ed8..4a813715fdd8 100644 --- a/airflow/providers/amazon/aws/log/s3_task_handler.py +++ b/airflow/providers/amazon/aws/log/s3_task_handler.py @@ -122,9 +122,10 @@ def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], l bucket, prefix = self.hook.parse_s3_url(s3url=os.path.join(self.remote_base, worker_log_rel_path)) keys = self.hook.list_keys(bucket_name=bucket, prefix=prefix) if keys: - keys = [f"s3://{bucket}/{key}" for key in keys] - messages.extend(["Found logs in s3:", *[f" * {x}" for x in sorted(keys)]]) - for key in sorted(keys): + keys = sorted(f"s3://{bucket}/{key}" for key in keys) + messages.append("Found logs in s3:") + messages.extend(f" * {key}" for key in keys) + for key in keys: logs.append(self.s3_read(key, return_error=True)) else: messages.append(f"No logs found on s3 for ti={ti}") diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index ed6aa79e9f2d..bdce036c312a 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -492,14 +492,14 @@ def __init__( def execute(self, context: Context): redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) self.log.info("Starting resume cluster") - while self._remaining_attempts >= 1: + while self._remaining_attempts: try: redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier) break except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error: - self._remaining_attempts = self._remaining_attempts - 1 + self._remaining_attempts -= 1 - if self._remaining_attempts > 0: + if self._remaining_attempts: self.log.error( "Unable to resume cluster. %d attempts remaining.", self._remaining_attempts ) @@ -580,14 +580,14 @@ def __init__( def execute(self, context: Context): redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) - while self._remaining_attempts >= 1: + while self._remaining_attempts: try: redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier) break except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error: - self._remaining_attempts = self._remaining_attempts - 1 + self._remaining_attempts -= 1 - if self._remaining_attempts > 0: + if self._remaining_attempts: self.log.error( "Unable to pause cluster. %d attempts remaining.", self._remaining_attempts ) @@ -669,7 +669,7 @@ def __init__( self.max_attempts = max_attempts def execute(self, context: Context): - while self._attempts >= 1: + while self._attempts: try: self.redshift_hook.delete_cluster( cluster_identifier=self.cluster_identifier, @@ -678,9 +678,9 @@ def execute(self, context: Context): ) break except self.redshift_hook.get_conn().exceptions.InvalidClusterStateFault: - self._attempts = self._attempts - 1 + self._attempts -= 1 - if self._attempts > 0: + if self._attempts: self.log.error("Unable to delete cluster. 
%d attempts remaining.", self._attempts) time.sleep(self._attempt_interval) else: diff --git a/airflow/providers/amazon/aws/secrets/secrets_manager.py b/airflow/providers/amazon/aws/secrets/secrets_manager.py index a5acf19a372e..f33f328c0021 100644 --- a/airflow/providers/amazon/aws/secrets/secrets_manager.py +++ b/airflow/providers/amazon/aws/secrets/secrets_manager.py @@ -217,10 +217,7 @@ def _standardize_secret_keys(self, secret: dict[str, Any]) -> dict[str, Any]: conn_d: dict[str, Any] = {} for conn_field, possible_words in possible_words_for_conn_fields.items(): - try: - conn_d[conn_field] = [v for k, v in secret.items() if k in possible_words][0] - except IndexError: - conn_d[conn_field] = None + conn_d[conn_field] = next((v for k, v in secret.items() if k in possible_words), None) return conn_d diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py index 45f1e1783c49..6b6b4890c093 100644 --- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py @@ -137,7 +137,7 @@ def __init__( if self.redshift_data_api_kwargs: for arg in ["sql", "parameters"]: - if arg in self.redshift_data_api_kwargs.keys(): + if arg in self.redshift_data_api_kwargs: raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs") def _build_unload_query( diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index b80b5bed3276..6bedb092b42b 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -116,7 +116,7 @@ def __init__( if self.redshift_data_api_kwargs: for arg in ["sql", "parameters"]: - if arg in self.redshift_data_api_kwargs.keys(): + if arg in self.redshift_data_api_kwargs: raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs") def _build_copy_query( diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py index 0324406820e2..1b327def97c9 100644 --- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py @@ -194,7 +194,7 @@ def _partition_dataframe(self, df: DataFrame) -> Iterable[tuple[str, DataFrame]] yield "", df else: grouped_df = df.groupby(**self.groupby_kwargs) - for group_label in grouped_df.groups.keys(): + for group_label in grouped_df.groups: yield group_label, grouped_df.get_group(group_label).reset_index(drop=True) def _get_hook(self) -> DbApiHook: diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py index 8966540a3849..40b230221fc0 100644 --- a/airflow/providers/amazon/aws/triggers/batch.py +++ b/airflow/providers/amazon/aws/triggers/batch.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio +import itertools as it from functools import cached_property from typing import Any @@ -161,9 +162,7 @@ async def run(self): """ async with self.hook.async_conn as client: waiter = self.hook.get_waiter("batch_job_complete", deferrable=True, client=client) - attempt = 0 - while True: - attempt = attempt + 1 + for attempt in it.count(1): try: await waiter.wait( jobs=[self.job_id], @@ -172,7 +171,6 @@ async def run(self): "MaxAttempts": 1, }, ) - break except WaiterError as error: if "error" in str(error): yield TriggerEvent({"status": "failure", "message": f"Job Failed: 
{error}"}) @@ -183,6 +181,8 @@ async def run(self): attempt, ) await asyncio.sleep(int(self.poke_interval)) + else: + break yield TriggerEvent( { diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py index cc203207d158..4a380d23f0fa 100644 --- a/airflow/providers/amazon/aws/triggers/ecs.py +++ b/airflow/providers/amazon/aws/triggers/ecs.py @@ -165,8 +165,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # fmt: on waiter = ecs_client.get_waiter("tasks_stopped") logs_token = None - while self.waiter_max_attempts >= 1: - self.waiter_max_attempts = self.waiter_max_attempts - 1 + while self.waiter_max_attempts: + self.waiter_max_attempts -= 1 try: await waiter.wait( cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1} diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py index d7d9844af33f..af8fef5f7bcf 100644 --- a/airflow/providers/amazon/aws/triggers/emr.py +++ b/airflow/providers/amazon/aws/triggers/emr.py @@ -70,10 +70,8 @@ async def run(self): self.hook = EmrHook(aws_conn_id=self.aws_conn_id) async with self.hook.async_conn as client: for step_id in self.step_ids: - attempt = 0 waiter = client.get_waiter("step_complete") - while attempt < int(self.max_attempts): - attempt += 1 + for attempt in range(1, 1 + self.max_attempts): try: await waiter.wait( ClusterId=self.job_flow_id, diff --git a/airflow/providers/amazon/aws/triggers/sagemaker.py b/airflow/providers/amazon/aws/triggers/sagemaker.py index ec11323d426e..7165276c4c91 100644 --- a/airflow/providers/amazon/aws/triggers/sagemaker.py +++ b/airflow/providers/amazon/aws/triggers/sagemaker.py @@ -164,12 +164,10 @@ def serialize(self) -> tuple[str, dict[str, Any]]: } async def run(self) -> AsyncIterator[TriggerEvent]: - attempts = 0 hook = SageMakerHook(aws_conn_id=self.aws_conn_id) async with hook.async_conn as conn: waiter = hook.get_waiter(self._waiter_name[self.waiter_type], deferrable=True, client=conn) - while attempts < self.waiter_max_attempts: - attempts = attempts + 1 + for _ in range(self.waiter_max_attempts): try: await waiter.wait( PipelineExecutionArn=self.pipeline_execution_arn, WaiterConfig={"MaxAttempts": 1} diff --git a/airflow/providers/amazon/aws/utils/waiter_with_logging.py b/airflow/providers/amazon/aws/utils/waiter_with_logging.py index 1e506927a1f4..2c0d40342653 100644 --- a/airflow/providers/amazon/aws/utils/waiter_with_logging.py +++ b/airflow/providers/amazon/aws/utils/waiter_with_logging.py @@ -63,22 +63,21 @@ def wait( status_args = ["Clusters[0].state", "Clusters[0].details"] """ log = logging.getLogger(__name__) - attempt = 0 - while True: - attempt += 1 + for attempt in range(waiter_max_attempts): + if attempt: + time.sleep(waiter_delay) try: waiter.wait(**args, WaiterConfig={"MaxAttempts": 1}) - break except WaiterError as error: if "terminal failure" in str(error): log.error("%s: %s", failure_message, _LazyStatusFormatter(status_args, error.last_response)) raise AirflowException(f"{failure_message}: {error}") log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, error.last_response)) - if attempt >= waiter_max_attempts: - raise AirflowException("Waiter error: max attempts reached") - - time.sleep(waiter_delay) + else: + break + else: + raise AirflowException("Waiter error: max attempts reached") async def async_wait( @@ -115,22 +114,21 @@ async def async_wait( status_args = ["Clusters[0].state", "Clusters[0].details"] """ log = logging.getLogger(__name__) - 
attempt = 0 - while True: - attempt += 1 + for attempt in range(waiter_max_attempts): + if attempt: + await asyncio.sleep(waiter_delay) try: await waiter.wait(**args, WaiterConfig={"MaxAttempts": 1}) - break except WaiterError as error: if "terminal failure" in str(error): log.error("%s: %s", failure_message, _LazyStatusFormatter(status_args, error.last_response)) raise AirflowException(f"{failure_message}: {error}") log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, error.last_response)) - if attempt >= waiter_max_attempts: - raise AirflowException("Waiter error: max attempts reached") - - await asyncio.sleep(waiter_delay) + else: + break + else: + raise AirflowException("Waiter error: max attempts reached") class _LazyStatusFormatter:
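Most of the retry rewrites above (poll_job_status, get_job_description, delete_bucket, wait, async_wait) converge on one shape: a bounded for loop that sleeps before every attempt except the first, returns or breaks on success, and lets the loop's else clause raise once the attempts are exhausted. A self-contained sketch of that shape, with an invented flaky_operation() standing in for the AWS calls and arbitrary backoff numbers:

import random
import time


class RetriesExhaustedError(Exception):
    pass


def exponential_delay(tries: int) -> float:
    # Jittered exponential backoff, capped so a demo run stays short.
    return min(0.1 * 2**tries, 1.0) * random.uniform(0.5, 1.0)


def flaky_operation() -> str:
    # Stand-in for an AWS API call that fails transiently most of the time.
    if random.random() < 0.7:
        raise RuntimeError("transient failure")
    return "done"


def call_with_retries(max_retries: int = 5) -> str:
    for attempt in range(1 + max_retries):
        if attempt:  # sleep before every attempt except the first
            pause = exponential_delay(attempt)
            print(f"retry {attempt} of {max_retries} in {pause:.2f}s")
            time.sleep(pause)
        try:
            return flaky_operation()
        except RuntimeError as err:
            print(f"attempt failed: {err}")
    else:
        # Reached only when the loop finishes without returning, i.e. every attempt failed.
        raise RetriesExhaustedError(f"gave up after {max_retries} retries")


if __name__ == "__main__":
    try:
        print(call_with_retries())
    except RetriesExhaustedError as exc:
        print(exc)

Pausing at the top of the loop, guarded by the "if attempt" check, keeps the success path free of a trailing sleep and removes the separate counter variable the old while True versions had to maintain.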
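The datasync.py and redshift_data.py hunks drop the separate next_token variable and key the pagination loop directly on whether the previous response carried a NextToken. A sketch of that pattern against a stubbed, in-memory list_locations() (the page contents and token values are invented):

from __future__ import annotations

# Fake paginated backend: three pages of "locations".
_PAGES = {
    None: {"Locations": ["loc-1", "loc-2"], "NextToken": "t1"},
    "t1": {"Locations": ["loc-3"], "NextToken": "t2"},
    "t2": {"Locations": ["loc-4"]},  # last page carries no NextToken
}


def list_locations(NextToken: str | None = None) -> dict:
    """Stub standing in for a paginated AWS API call."""
    return _PAGES[NextToken]


def all_locations() -> list[str]:
    # First call without a token, then keep following NextToken while it is present.
    response = list_locations()
    locations = list(response["Locations"])  # copy so the stub's page data is not mutated
    while "NextToken" in response:
        response = list_locations(NextToken=response["NextToken"])
        locations.extend(response["Locations"])
    return locations


if __name__ == "__main__":
    print(all_locations())  # -> ['loc-1', 'loc-2', 'loc-3', 'loc-4']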
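Two smaller idioms from the patch, shown together on invented data: itertools.product flattens the nested stream/option loops of get_job_all_awslogs_info, and next() with a default replaces the try/except IndexError around a [...][0] lookup in _standardize_secret_keys:

import itertools as it

stream_names = ["stream-a", "stream-b"]
log_options = [{"awslogs-group": "/aws/batch/job"}, {"awslogs-region": "eu-west-1"}]

# Cross product of two lists without nesting the loops.
log_infos = [
    {
        "awslogs_stream_name": stream,
        "awslogs_group": option.get("awslogs-group", "/aws/batch/job"),
        "awslogs_region": option.get("awslogs-region", "us-east-1"),
    }
    for stream, option in it.product(stream_names, log_options)
]
print(len(log_infos))  # -> 4, one entry per (stream, option) pair

# First matching value with a default, instead of try/except IndexError around [...][0].
secret = {"login": "alice", "host": "example.com"}
possible_words = {"user", "username", "login"}
user = next((v for k, v in secret.items() if k in possible_words), None)
print(user)  # -> alice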