From 9a9be62fae4812644c2d9cd243499505aaa7d471 Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Wed, 25 May 2022 09:13:54 +1000 Subject: [PATCH 1/8] Various clean-up --- airflow/providers/amazon/aws/hooks/athena.py | 3 +- airflow/providers/amazon/aws/hooks/glue.py | 11 +++---- .../amazon/aws/hooks/glue_crawler.py | 29 ++++++++-------- airflow/providers/amazon/aws/hooks/kinesis.py | 4 +-- .../amazon/aws/hooks/lambda_function.py | 6 ++-- .../amazon/aws/hooks/redshift_sql.py | 4 +-- airflow/providers/amazon/aws/hooks/s3.py | 13 +++----- .../amazon/aws/log/s3_task_handler.py | 4 +-- .../providers/amazon/aws/operators/glacier.py | 3 +- .../amazon/aws/secrets/secrets_manager.py | 33 +++++++++---------- .../amazon/aws/secrets/systems_manager.py | 3 +- airflow/providers/amazon/aws/sensors/emr.py | 2 +- airflow/providers/amazon/aws/sensors/glue.py | 2 +- .../providers/amazon/aws/sensors/sagemaker.py | 4 +-- .../amazon/aws/transfers/google_api_to_s3.py | 3 +- .../amazon/aws/transfers/mysql_to_s3.py | 5 ++- .../amazon/aws/utils/eks_get_token.py | 5 +-- 17 files changed, 57 insertions(+), 77 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py index bbd08f9034529..82e69ccf6f001 100644 --- a/airflow/providers/amazon/aws/hooks/athena.py +++ b/airflow/providers/amazon/aws/hooks/athena.py @@ -91,8 +91,7 @@ def run_query( if client_request_token: params['ClientRequestToken'] = client_request_token response = self.get_conn().start_query_execution(**params) - query_execution_id = response['QueryExecutionId'] - return query_execution_id + return response['QueryExecutionId'] def check_query_status(self, query_execution_id: str) -> Optional[str]: """ diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py index f0170f358f73a..6b652e9fc7b75 100644 --- a/airflow/providers/amazon/aws/hooks/glue.py +++ b/airflow/providers/amazon/aws/hooks/glue.py @@ -74,7 +74,7 @@ def __init__( raise ValueError("Cannot specify num_of_dpus with custom WorkerType") elif not worker_type_exists and num_workers_exists: raise ValueError("Need to specify custom WorkerType when specifying NumberOfWorkers") - elif worker_type_exists and not num_workers_exists: + elif worker_type_exists: raise ValueError("Need to specify NumberOfWorkers when specifying custom WorkerType") elif num_of_dpus is None: self.num_of_dpus = 10 @@ -118,8 +118,8 @@ def initialize_job( try: job_name = self.get_or_create_glue_job() - job_run = glue_client.start_job_run(JobName=job_name, Arguments=script_arguments, **run_kwargs) - return job_run + return glue_client.start_job_run(JobName=job_name, Arguments=script_arguments, **run_kwargs) + except Exception as general_error: self.log.error("Failed to run aws glue job, error: %s", general_error) raise @@ -134,8 +134,7 @@ def get_job_state(self, job_name: str, run_id: str) -> str: """ glue_client = self.get_conn() job_run = glue_client.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True) - job_run_state = job_run['JobRun']['JobRunState'] - return job_run_state + return job_run['JobRun']['JobRunState'] def job_completion(self, job_name: str, run_id: str) -> Dict[str, str]: """ @@ -155,7 +154,7 @@ def job_completion(self, job_name: str, run_id: str) -> Dict[str, str]: self.log.info("Exiting Job %s Run State: %s", run_id, job_run_state) return {'JobRunState': job_run_state, 'JobRunId': run_id} if job_run_state in failed_states: - job_error_message = "Exiting Job " + run_id + " Run State: " + 
job_run_state + job_error_message = f"Exiting Job {run_id} Run State: {job_run_state}" self.log.info(job_error_message) raise AirflowException(job_error_message) else: diff --git a/airflow/providers/amazon/aws/hooks/glue_crawler.py b/airflow/providers/amazon/aws/hooks/glue_crawler.py index 00d438aaf56aa..65f7df8d28566 100644 --- a/airflow/providers/amazon/aws/hooks/glue_crawler.py +++ b/airflow/providers/amazon/aws/hooks/glue_crawler.py @@ -102,8 +102,7 @@ def create_crawler(self, **crawler_kwargs) -> str: """ crawler_name = crawler_kwargs['Name'] self.log.info("Creating crawler: %s", crawler_name) - crawler = self.glue_client.create_crawler(**crawler_kwargs) - return crawler + return self.glue_client.create_crawler(**crawler_kwargs) def start_crawler(self, crawler_name: str) -> dict: """ @@ -113,8 +112,7 @@ def start_crawler(self, crawler_name: str) -> dict: :return: Empty dictionary """ self.log.info("Starting crawler %s", crawler_name) - crawler = self.glue_client.start_crawler(Name=crawler_name) - return crawler + return self.glue_client.start_crawler(Name=crawler_name) def wait_for_crawler_completion(self, crawler_name: str, poll_interval: int = 5) -> str: """ @@ -137,18 +135,17 @@ def wait_for_crawler_completion(self, crawler_name: str, poll_interval: int = 5) crawler_status = crawler['LastCrawl']['Status'] if crawler_status in failed_status: raise AirflowException(f"Status: {crawler_status}") - else: - metrics = self.glue_client.get_crawler_metrics(CrawlerNameList=[crawler_name])[ - 'CrawlerMetricsList' - ][0] - self.log.info("Status: %s", crawler_status) - self.log.info("Last Runtime Duration (seconds): %s", metrics['LastRuntimeSeconds']) - self.log.info("Median Runtime Duration (seconds): %s", metrics['MedianRuntimeSeconds']) - self.log.info("Tables Created: %s", metrics['TablesCreated']) - self.log.info("Tables Updated: %s", metrics['TablesUpdated']) - self.log.info("Tables Deleted: %s", metrics['TablesDeleted']) - - return crawler_status + metrics = self.glue_client.get_crawler_metrics(CrawlerNameList=[crawler_name])[ + 'CrawlerMetricsList' + ][0] + self.log.info("Status: %s", crawler_status) + self.log.info("Last Runtime Duration (seconds): %s", metrics['LastRuntimeSeconds']) + self.log.info("Median Runtime Duration (seconds): %s", metrics['MedianRuntimeSeconds']) + self.log.info("Tables Created: %s", metrics['TablesCreated']) + self.log.info("Tables Updated: %s", metrics['TablesUpdated']) + self.log.info("Tables Deleted: %s", metrics['TablesDeleted']) + + return crawler_status else: self.log.info("Polling for AWS Glue crawler: %s ", crawler_name) diff --git a/airflow/providers/amazon/aws/hooks/kinesis.py b/airflow/providers/amazon/aws/hooks/kinesis.py index 8f26b54d64b12..f15457c6c28d7 100644 --- a/airflow/providers/amazon/aws/hooks/kinesis.py +++ b/airflow/providers/amazon/aws/hooks/kinesis.py @@ -43,9 +43,7 @@ def __init__(self, delivery_stream: str, *args, **kwargs) -> None: def put_records(self, records: Iterable): """Write batch records to Kinesis Firehose""" - response = self.get_conn().put_record_batch(DeliveryStreamName=self.delivery_stream, Records=records) - - return response + return self.get_conn().put_record_batch(DeliveryStreamName=self.delivery_stream, Records=records) class AwsFirehoseHook(FirehoseHook): diff --git a/airflow/providers/amazon/aws/hooks/lambda_function.py b/airflow/providers/amazon/aws/hooks/lambda_function.py index 0f04af7e3b57b..b6819d9dba263 100644 --- a/airflow/providers/amazon/aws/hooks/lambda_function.py +++ 
b/airflow/providers/amazon/aws/hooks/lambda_function.py @@ -66,8 +66,7 @@ def invoke_lambda( "Payload": payload, "Qualifier": qualifier, } - response = self.conn.invoke(**{k: v for k, v in invoke_args.items() if v is not None}) - return response + return self.conn.invoke(**{k: v for k, v in invoke_args.items() if v is not None}) def create_lambda( self, @@ -118,10 +117,9 @@ def create_lambda( "CodeSigningConfigArn": code_signing_config_arn, "Architectures": architectures, } - response = self.conn.create_function( + return self.conn.create_function( **{k: v for k, v in create_function_args.items() if v is not None}, ) - return response class AwsLambdaHook(LambdaHook): diff --git a/airflow/providers/amazon/aws/hooks/redshift_sql.py b/airflow/providers/amazon/aws/hooks/redshift_sql.py index 0b889063cc10a..03bb45f7ee128 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_sql.py +++ b/airflow/providers/amazon/aws/hooks/redshift_sql.py @@ -130,6 +130,4 @@ def get_conn(self) -> RedshiftConnection: conn_params = self._get_conn_params() conn_kwargs_dejson = self.conn.extra_dejson conn_kwargs: Dict = {**conn_params, **conn_kwargs_dejson} - conn: RedshiftConnection = redshift_connector.connect(**conn_kwargs) - - return conn + return redshift_connector.connect(**conn_kwargs) diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index 71ef78aabe3ae..e7e9f2de508da 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -277,11 +277,10 @@ def list_prefixes( Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config ) - prefixes = [] + prefixes = [] # type: List[str] for page in response: if 'CommonPrefixes' in page: - for common_prefix in page['CommonPrefixes']: - prefixes.append(common_prefix['Prefix']) + prefixes.extend(common_prefix['Prefix'] for common_prefix in page['CommonPrefixes']) return prefixes @@ -366,12 +365,10 @@ def _is_in_period(input_date: datetime) -> bool: StartAfter=start_after_key, ) - keys = [] + keys = [] # type: List[str] for page in response: if 'Contents' in page: - for k in page['Contents']: - keys.append(k) - + keys.extend(iter(page['Contents'])) if self.object_filter_usr is not None: return self.object_filter_usr(keys, from_datetime, to_datetime) @@ -604,7 +601,7 @@ def load_file( extra_args['ServerSideEncryption'] = "AES256" if gzip: with open(filename, 'rb') as f_in: - filename_gz = f_in.name + '.gz' + filename_gz = f'{f_in.name}.gz' with gz.open(filename_gz, 'wb') as f_out: shutil.copyfileobj(f_in, f_out) filename = filename_gz diff --git a/airflow/providers/amazon/aws/log/s3_task_handler.py b/airflow/providers/amazon/aws/log/s3_task_handler.py index 695c4623d97b2..ce3da88f16916 100644 --- a/airflow/providers/amazon/aws/log/s3_task_handler.py +++ b/airflow/providers/amazon/aws/log/s3_task_handler.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. 
import os +import pathlib import sys if sys.version_info >= (3, 8): @@ -92,8 +93,7 @@ def close(self): remote_loc = os.path.join(self.remote_base, self.log_relative_path) if os.path.exists(local_loc): # read log and remove old logs to get just the latest additions - with open(local_loc) as logfile: - log = logfile.read() + log = pathlib.Path(local_loc).read_text() self.s3_write(log, remote_loc) # Mark closed so we don't double write if close is called twice diff --git a/airflow/providers/amazon/aws/operators/glacier.py b/airflow/providers/amazon/aws/operators/glacier.py index 27904e4d1d0a3..337492a4523e1 100644 --- a/airflow/providers/amazon/aws/operators/glacier.py +++ b/airflow/providers/amazon/aws/operators/glacier.py @@ -51,5 +51,4 @@ def __init__( def execute(self, context: 'Context'): hook = GlacierHook(aws_conn_id=self.aws_conn_id) - response = hook.retrieve_inventory(vault_name=self.vault_name) - return response + return hook.retrieve_inventory(vault_name=self.vault_name) diff --git a/airflow/providers/amazon/aws/secrets/secrets_manager.py b/airflow/providers/amazon/aws/secrets/secrets_manager.py index cd23372f878fd..8b72f955ac6d5 100644 --- a/airflow/providers/amazon/aws/secrets/secrets_manager.py +++ b/airflow/providers/amazon/aws/secrets/secrets_manager.py @@ -134,7 +134,7 @@ def __init__( self.profile_name = profile_name self.sep = sep self.full_url_mode = full_url_mode - self.extra_conn_words = extra_conn_words if extra_conn_words else {} + self.extra_conn_words = extra_conn_words or {} self.kwargs = kwargs @cached_property @@ -178,9 +178,7 @@ def get_uri_from_secret(self, secret): conn_string = "{conn_type}://{user}:{password}@{host}:{port}/{schema}".format(**conn_d) - connection = self._format_uri_with_extra(secret, conn_string) - - return connection + return self._format_uri_with_extra(secret, conn_string) def get_conn_value(self, conn_id: str): """ @@ -193,20 +191,19 @@ def get_conn_value(self, conn_id: str): if self.full_url_mode: return self._get_secret(self.connections_prefix, conn_id) - else: - try: - secret_string = self._get_secret(self.connections_prefix, conn_id) - # json.loads gives error - secret = ast.literal_eval(secret_string) if secret_string else None - except ValueError: # 'malformed node or string: ' error, for empty conns - connection = None - secret = None - - # These lines will check if we have with some denomination stored an username, password and host - if secret: - connection = self.get_uri_from_secret(secret) - - return connection + try: + secret_string = self._get_secret(self.connections_prefix, conn_id) + # json.loads gives error + secret = ast.literal_eval(secret_string) if secret_string else None + except ValueError: # 'malformed node or string: ' error, for empty conns + connection = None + secret = None + + # These lines will check if we have with some denomination stored an username, password and host + if secret: + connection = self.get_uri_from_secret(secret) + + return connection def get_conn_uri(self, conn_id: str) -> Optional[str]: """ diff --git a/airflow/providers/amazon/aws/secrets/systems_manager.py b/airflow/providers/amazon/aws/secrets/systems_manager.py index 310e3337bcf1e..e45a5500ab003 100644 --- a/airflow/providers/amazon/aws/secrets/systems_manager.py +++ b/airflow/providers/amazon/aws/secrets/systems_manager.py @@ -158,8 +158,7 @@ def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]: ssm_path = self.build_path(path_prefix, secret_id) try: response = self.client.get_parameter(Name=ssm_path, 
WithDecryption=True) - value = response["Parameter"]["Value"] - return value + return response["Parameter"]["Value"] except self.client.exceptions.ParameterNotFound: self.log.debug("Parameter %s not found.", ssm_path) return None diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py index c1f4a449a4e6a..a4c2b3a71142b 100644 --- a/airflow/providers/amazon/aws/sensors/emr.py +++ b/airflow/providers/amazon/aws/sensors/emr.py @@ -66,7 +66,7 @@ def get_hook(self) -> EmrHook: def poke(self, context: 'Context'): response = self.get_emr_response() - if not response['ResponseMetadata']['HTTPStatusCode'] == 200: + if response['ResponseMetadata']['HTTPStatusCode'] != 200: self.log.info('Bad HTTP response: %s', response) return False diff --git a/airflow/providers/amazon/aws/sensors/glue.py b/airflow/providers/amazon/aws/sensors/glue.py index 59bdaa4c4eff3..525e7b8ee6234 100644 --- a/airflow/providers/amazon/aws/sensors/glue.py +++ b/airflow/providers/amazon/aws/sensors/glue.py @@ -57,7 +57,7 @@ def poke(self, context: 'Context'): self.log.info("Exiting Job %s Run State: %s", self.run_id, job_state) return True elif job_state in self.errored_states: - job_error_message = "Exiting Job " + self.run_id + " Run State: " + job_state + job_error_message = f"Exiting Job {self.run_id} Run State: {job_state}" raise AirflowException(job_error_message) else: return False diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py index 3cf6dceef154c..925ddaed17fd9 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker.py @@ -50,7 +50,7 @@ def get_hook(self) -> SageMakerHook: def poke(self, context: 'Context'): response = self.get_sagemaker_response() - if not (response['ResponseMetadata']['HTTPStatusCode'] == 200): + if response['ResponseMetadata']['HTTPStatusCode'] != 200: self.log.info('Bad HTTP response: %s', response) return False state = self.state_from_response(response) @@ -225,7 +225,7 @@ def init_log_resource(self, hook: SageMakerHook) -> None: self.instance_count = description['ResourceConfig']['InstanceCount'] status = description['TrainingJobStatus'] job_already_completed = status not in self.non_terminal_states() - self.state = LogState.TAILING if (not job_already_completed) else LogState.COMPLETE + self.state = LogState.COMPLETE if job_already_completed else LogState.TAILING self.last_description = description self.last_describe_job_call = time.monotonic() self.log_resource_inited = True diff --git a/airflow/providers/amazon/aws/transfers/google_api_to_s3.py b/airflow/providers/amazon/aws/transfers/google_api_to_s3.py index 34b590cb2d209..f3e10b62b2aa0 100644 --- a/airflow/providers/amazon/aws/transfers/google_api_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/google_api_to_s3.py @@ -155,13 +155,12 @@ def _retrieve_data_from_google_api(self) -> dict: api_version=self.google_api_service_version, impersonation_chain=self.google_impersonation_chain, ) - google_api_response = google_discovery_api_hook.query( + return google_discovery_api_hook.query( endpoint=self.google_api_endpoint_path, data=self.google_api_endpoint_params, paginate=self.google_api_pagination, num_retries=self.google_api_num_retries, ) - return google_api_response def _load_data_to_s3(self, data: dict) -> None: s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) diff --git a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py 
b/airflow/providers/amazon/aws/transfers/mysql_to_s3.py index 728aaddcba0a9..dc3d84ecb3658 100644 --- a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/mysql_to_s3.py @@ -66,8 +66,7 @@ def __init__( if "header" not in pd_kwargs: pd_kwargs["header"] = header kwargs["pd_kwargs"] = {**kwargs.get('pd_kwargs', {}), **pd_kwargs} - else: - if pd_csv_kwargs is not None: - raise TypeError("pd_csv_kwargs may not be specified when file_format='parquet'") + elif pd_csv_kwargs is not None: + raise TypeError("pd_csv_kwargs may not be specified when file_format='parquet'") super().__init__(sql_conn_id=mysql_conn_id, **kwargs) diff --git a/airflow/providers/amazon/aws/utils/eks_get_token.py b/airflow/providers/amazon/aws/utils/eks_get_token.py index 145b8417bbd1e..d9422b35e301c 100644 --- a/airflow/providers/amazon/aws/utils/eks_get_token.py +++ b/airflow/providers/amazon/aws/utils/eks_get_token.py @@ -17,7 +17,7 @@ import argparse import json -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from airflow.providers.amazon.aws.hooks.eks import EksHook @@ -27,7 +27,8 @@ def get_expiration_time(): - token_expiration = datetime.utcnow() + timedelta(minutes=TOKEN_EXPIRATION_MINUTES) + token_expiration = datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRATION_MINUTES) + return token_expiration.strftime('%Y-%m-%dT%H:%M:%SZ') From 8112f45c14d96f9932d746df1fa67c4dbd73dd3c Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Wed, 25 May 2022 09:21:49 +1000 Subject: [PATCH 2/8] Update eks_test_utils.py --- .../amazon/aws/utils/eks_test_utils.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tests/providers/amazon/aws/utils/eks_test_utils.py b/tests/providers/amazon/aws/utils/eks_test_utils.py index 4fa2b7b7662fa..01ba4ac779683 100644 --- a/tests/providers/amazon/aws/utils/eks_test_utils.py +++ b/tests/providers/amazon/aws/utils/eks_test_utils.py @@ -64,7 +64,7 @@ def attributes_to_test( # The below tag is mandatory and must have a value of either 'owned' or 'shared' # A value of 'owned' denotes that the subnets are exclusive to the nodegroup. # The 'shared' value allows more than one resource to use the subnet. - required_tag: Dict = {'kubernetes.io/cluster/' + cluster_name: 'owned'} + required_tag: Dict = {f'kubernetes.io/cluster/{cluster_name}': 'owned'} # Find the user-submitted tag set and append the required tag to it. 
final_tag_set: Dict = required_tag for key, value in result: @@ -93,7 +93,7 @@ def generate_clusters(eks_hook: EksHook, num_clusters: int, minimal: bool) -> Li """ # Generates N clusters named cluster0, cluster1, .., clusterN return [ - eks_hook.create_cluster(name="cluster" + str(count), **_input_builder(ClusterInputs, minimal))[ + eks_hook.create_cluster(name=f"cluster{str(count)}", **_input_builder(ClusterInputs, minimal))[ ResponseAttributes.CLUSTER ][ClusterAttributes.NAME] for count in range(num_clusters) @@ -116,7 +116,7 @@ def generate_fargate_profiles( # Generates N Fargate profiles named profile0, profile1, .., profileN return [ eks_hook.create_fargate_profile( - fargateProfileName="profile" + str(count), + fargateProfileName=f"profile{str(count)}", clusterName=cluster_name, **_input_builder(FargateProfileInputs, minimal), )[ResponseAttributes.FARGATE_PROFILE][FargateProfileAttributes.FARGATE_PROFILE_NAME] @@ -140,7 +140,7 @@ def generate_nodegroups( # Generates N nodegroups named nodegroup0, nodegroup1, .., nodegroupN return [ eks_hook.create_nodegroup( - nodegroupName="nodegroup" + str(count), + nodegroupName=f"nodegroup{str(count)}", clusterName=cluster_name, **_input_builder(NodegroupInputs, minimal), )[ResponseAttributes.NODEGROUP][NodegroupAttributes.NODEGROUP_NAME] @@ -164,10 +164,14 @@ def region_matches_partition(region: str, partition: str) -> bool: ("us-gov-iso-b-", "aws-iso-b"), ] - for prefix, expected_partition in valid_matches: - if region.startswith(prefix): - return partition == expected_partition - return partition == "aws" + return next( + ( + partition == expected_partition + for prefix, expected_partition in valid_matches + if region.startswith(prefix) + ), + partition == "aws", + ) def _input_builder(options: InputTypes, minimal: bool) -> Dict: From 17263d6d689a6d3ece599fda9ecbf00bf8c2491f Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Wed, 25 May 2022 09:25:17 +1000 Subject: [PATCH 3/8] f-string concat in tests --- tests/providers/amazon/aws/transfers/test_s3_to_sftp.py | 2 +- tests/providers/amazon/aws/transfers/test_sftp_to_s3.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py b/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py index ddd4ca378e149..f7e8b5f4f4971 100644 --- a/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py +++ b/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py @@ -50,7 +50,7 @@ def setUp(self): s3_hook = S3Hook('aws_default') hook.no_host_key_check = True dag = DAG( - TEST_DAG_ID + 'test_schedule_dag_once', + f'{TEST_DAG_ID}test_schedule_dag_once', start_date=DEFAULT_DATE, schedule_interval='@once', ) diff --git a/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py b/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py index b24c89b7f8d59..e5dcbc7254c80 100644 --- a/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py @@ -51,7 +51,7 @@ def setUp(self): s3_hook = S3Hook('aws_default') hook.no_host_key_check = True dag = DAG( - TEST_DAG_ID + 'test_schedule_dag_once', + f'{TEST_DAG_ID}test_schedule_dag_once', schedule_interval="@once", start_date=DEFAULT_DATE, ) From 35a2f0a1718674095b822c545af971d7e25a87dc Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Wed, 25 May 2022 09:29:53 +1000 Subject: [PATCH 4/8] Test cleanup --- tests/providers/amazon/aws/log/test_s3_task_handler.py | 5 ++--- tests/providers/amazon/aws/operators/test_athena.py | 3 ++- 
.../amazon/aws/operators/test_dms_describe_tasks.py | 6 +----- .../providers/amazon/aws/sensors/test_s3_keys_unchanged.py | 5 ++--- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/tests/providers/amazon/aws/log/test_s3_task_handler.py b/tests/providers/amazon/aws/log/test_s3_task_handler.py index a322f167eccc2..b647f95ba4360 100644 --- a/tests/providers/amazon/aws/log/test_s3_task_handler.py +++ b/tests/providers/amazon/aws/log/test_s3_task_handler.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. +import contextlib import os import unittest from unittest import mock @@ -75,10 +76,8 @@ def setUp(self): def tearDown(self): if self.s3_task_handler.handler: - try: + with contextlib.suppress(Exception): os.remove(self.s3_task_handler.handler.baseFilename) - except Exception: - pass def test_hook(self): assert isinstance(self.s3_task_handler.hook, S3Hook) diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index a1cfc8478d9e8..2ec96495e6537 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -50,7 +50,8 @@ def setUp(self): 'start_date': DEFAULT_DATE, } - self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args, schedule_interval='@once') + self.dag = DAG(f'{TEST_DAG_ID}test_schedule_dag_once', default_args=args, schedule_interval='@once') + self.athena = AthenaOperator( task_id='test_athena_operator', query='SELECT * FROM TEST_TABLE', diff --git a/tests/providers/amazon/aws/operators/test_dms_describe_tasks.py b/tests/providers/amazon/aws/operators/test_dms_describe_tasks.py index ebf4c11989d86..3c7c587d46153 100644 --- a/tests/providers/amazon/aws/operators/test_dms_describe_tasks.py +++ b/tests/providers/amazon/aws/operators/test_dms_describe_tasks.py @@ -57,11 +57,7 @@ def setUp(self): "start_date": DEFAULT_DATE, } - self.dag = DAG( - TEST_DAG_ID + "test_schedule_dag_once", - default_args=args, - schedule_interval="@once", - ) + self.dag = DAG(f"{TEST_DAG_ID}test_schedule_dag_once", default_args=args, schedule_interval="@once") def test_init(self): dms_operator = DmsDescribeTasksOperator( diff --git a/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py b/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py index 1c1d85242b330..f7767f81e0aeb 100644 --- a/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py +++ b/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py @@ -33,10 +33,9 @@ class TestS3KeysUnchangedSensor(TestCase): def setUp(self): self.dag = DAG( - TEST_DAG_ID + 'test_schedule_dag_once', - start_date=DEFAULT_DATE, - schedule_interval="@once", + f'{TEST_DAG_ID}test_schedule_dag_once', start_date=DEFAULT_DATE, schedule_interval="@once" ) + self.sensor = S3KeysUnchangedSensor( task_id='sensor_1', bucket_name='test-bucket', From 5957b61a0fcfee30c04f68d6cb7e36b1c1d4645d Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Wed, 25 May 2022 09:31:43 +1000 Subject: [PATCH 5/8] Update test_cloud_formation.py --- tests/providers/amazon/aws/hooks/test_cloud_formation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/amazon/aws/hooks/test_cloud_formation.py b/tests/providers/amazon/aws/hooks/test_cloud_formation.py index 35417e5e8045b..4739def7725c1 100644 --- a/tests/providers/amazon/aws/hooks/test_cloud_formation.py +++ b/tests/providers/amazon/aws/hooks/test_cloud_formation.py @@ -102,4 +102,4 @@ def 
test_delete_stack(self): stacks = self.hook.get_conn().describe_stacks()['Stacks'] matching_stacks = [x for x in stacks if x['StackName'] == stack_name] - assert len(matching_stacks) == 0, f'stack with name {stack_name} should not exist' + assert not matching_stacks, f'stack with name {stack_name} should not exist' From cdd29e5ce55aee76f033585f520c3e6076182b15 Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Thu, 26 May 2022 09:24:16 +1000 Subject: [PATCH 6/8] Update glue.py --- airflow/providers/amazon/aws/hooks/glue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py index 6b652e9fc7b75..fc485931a2d90 100644 --- a/airflow/providers/amazon/aws/hooks/glue.py +++ b/airflow/providers/amazon/aws/hooks/glue.py @@ -74,7 +74,7 @@ def __init__( raise ValueError("Cannot specify num_of_dpus with custom WorkerType") elif not worker_type_exists and num_workers_exists: raise ValueError("Need to specify custom WorkerType when specifying NumberOfWorkers") - elif worker_type_exists: + elif worker_type_exists and not num_workers_exists:: raise ValueError("Need to specify NumberOfWorkers when specifying custom WorkerType") elif num_of_dpus is None: self.num_of_dpus = 10 From 50642bd832966b98dd1f5e9461fb3cbdd3b2aa3c Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Thu, 26 May 2022 09:25:49 +1000 Subject: [PATCH 7/8] Update eks_test_utils.py --- tests/providers/amazon/aws/utils/eks_test_utils.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/providers/amazon/aws/utils/eks_test_utils.py b/tests/providers/amazon/aws/utils/eks_test_utils.py index 01ba4ac779683..235658f777095 100644 --- a/tests/providers/amazon/aws/utils/eks_test_utils.py +++ b/tests/providers/amazon/aws/utils/eks_test_utils.py @@ -164,14 +164,10 @@ def region_matches_partition(region: str, partition: str) -> bool: ("us-gov-iso-b-", "aws-iso-b"), ] - return next( - ( - partition == expected_partition - for prefix, expected_partition in valid_matches - if region.startswith(prefix) - ), - partition == "aws", - ) + for prefix, expected_partition in valid_matches: + if region.startswith(prefix): + return partition == expected_partition + return partition == "aws" def _input_builder(options: InputTypes, minimal: bool) -> Dict: From 8a89c284026f2f3ea8c3e26a5024fcff451e8667 Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Fri, 27 May 2022 13:59:18 +1000 Subject: [PATCH 8/8] Update glue.py --- airflow/providers/amazon/aws/hooks/glue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py index fc485931a2d90..dcd6d7c4661cd 100644 --- a/airflow/providers/amazon/aws/hooks/glue.py +++ b/airflow/providers/amazon/aws/hooks/glue.py @@ -74,7 +74,7 @@ def __init__( raise ValueError("Cannot specify num_of_dpus with custom WorkerType") elif not worker_type_exists and num_workers_exists: raise ValueError("Need to specify custom WorkerType when specifying NumberOfWorkers") - elif worker_type_exists and not num_workers_exists:: + elif worker_type_exists and not num_workers_exists: raise ValueError("Need to specify NumberOfWorkers when specifying custom WorkerType") elif num_of_dpus is None: self.num_of_dpus = 10
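
Note on the glue.py changes: the WorkerType/NumberOfWorkers validation in airflow/providers/amazon/aws/hooks/glue.py is touched three times in this series (PATCH 1/8 drops the "and not num_workers_exists" guard, PATCH 6/8 restores it but with a stray double colon, and PATCH 8/8 removes that typo), so the block ends up identical to where it started. For reference, below is a minimal runnable sketch, not the actual GlueJobHook, of that validation as it reads after the final commit; the GlueJobConfig class and the usage lines at the bottom are illustrative assumptions only.

# A minimal, self-contained sketch of the WorkerType / NumberOfWorkers validation
# as it reads after PATCH 8/8. GlueJobConfig is a hypothetical stand-in for
# GlueJobHook; only the parameters exercised by these patches are modelled.
from typing import Optional


class GlueJobConfig:
    def __init__(self, create_job_kwargs: Optional[dict] = None, num_of_dpus: Optional[int] = None):
        self.create_job_kwargs = create_job_kwargs or {}
        worker_type_exists = "WorkerType" in self.create_job_kwargs
        num_workers_exists = "NumberOfWorkers" in self.create_job_kwargs

        if worker_type_exists and num_workers_exists:
            # Both worker settings are present: a DPU count may not be given as well.
            if num_of_dpus is not None:
                raise ValueError("Cannot specify num_of_dpus with custom WorkerType")
        elif not worker_type_exists and num_workers_exists:
            raise ValueError("Need to specify custom WorkerType when specifying NumberOfWorkers")
        elif worker_type_exists and not num_workers_exists:
            # The branch removed in PATCH 1/8 and restored by PATCH 6/8 and 8/8.
            raise ValueError("Need to specify NumberOfWorkers when specifying custom WorkerType")
        elif num_of_dpus is None:
            self.num_of_dpus = 10
        else:
            self.num_of_dpus = num_of_dpus


if __name__ == "__main__":
    # Accepted: WorkerType and NumberOfWorkers supplied together, no DPU count.
    GlueJobConfig(create_job_kwargs={"WorkerType": "G.1X", "NumberOfWorkers": 2})
    # Rejected: WorkerType without NumberOfWorkers trips the restored check.
    try:
        GlueJobConfig(create_job_kwargs={"WorkerType": "G.1X"})
    except ValueError as err:
        print(err)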