Light Refactor and Clean-up AWS Provider (#23907)
Vincent Koc authored May 31, 2022
1 parent ab1f637 commit 595981c
Showing 25 changed files with 70 additions and 95 deletions.
3 changes: 1 addition & 2 deletions airflow/providers/amazon/aws/hooks/athena.py
@@ -91,8 +91,7 @@ def run_query(
         if client_request_token:
             params['ClientRequestToken'] = client_request_token
         response = self.get_conn().start_query_execution(**params)
-        query_execution_id = response['QueryExecutionId']
-        return query_execution_id
+        return response['QueryExecutionId']

     def check_query_status(self, query_execution_id: str) -> Optional[str]:
         """
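This is the pattern repeated throughout the commit: a value bound to a temporary variable and immediately returned is collapsed into a single return expression. A minimal standalone sketch of the idiom (the helper function below is invented for illustration, not part of the Airflow codebase):

import boto3

def start_query(params: dict) -> str:
    """Start an Athena query and return its execution id."""
    client = boto3.client('athena')
    # Before: response = client.start_query_execution(**params)
    #         query_execution_id = response['QueryExecutionId']
    #         return query_execution_id
    # After: index into the response and return in one expression.
    return client.start_query_execution(**params)['QueryExecutionId']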
9 changes: 4 additions & 5 deletions airflow/providers/amazon/aws/hooks/glue.py
@@ -118,8 +118,8 @@ def initialize_job(

         try:
             job_name = self.get_or_create_glue_job()
-            job_run = glue_client.start_job_run(JobName=job_name, Arguments=script_arguments, **run_kwargs)
-            return job_run
+            return glue_client.start_job_run(JobName=job_name, Arguments=script_arguments, **run_kwargs)
+
         except Exception as general_error:
             self.log.error("Failed to run aws glue job, error: %s", general_error)
             raise
@@ -134,8 +134,7 @@ def get_job_state(self, job_name: str, run_id: str) -> str:
         """
         glue_client = self.get_conn()
         job_run = glue_client.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True)
-        job_run_state = job_run['JobRun']['JobRunState']
-        return job_run_state
+        return job_run['JobRun']['JobRunState']

     def job_completion(self, job_name: str, run_id: str) -> Dict[str, str]:
         """
@@ -155,7 +154,7 @@ def job_completion(self, job_name: str, run_id: str) -> Dict[str, str]:
                 self.log.info("Exiting Job %s Run State: %s", run_id, job_run_state)
                 return {'JobRunState': job_run_state, 'JobRunId': run_id}
             if job_run_state in failed_states:
-                job_error_message = "Exiting Job " + run_id + " Run State: " + job_run_state
+                job_error_message = f"Exiting Job {run_id} Run State: {job_run_state}"
                 self.log.info(job_error_message)
                 raise AirflowException(job_error_message)
             else:
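The job_completion hunk also swaps string concatenation for an f-string, which avoids the repeated quoting and + operators. A tiny self-contained illustration (the run id and state values are fabricated):

run_id = 'jr_0123456789abcdef'
job_run_state = 'FAILED'

# Before: "Exiting Job " + run_id + " Run State: " + job_run_state
job_error_message = f"Exiting Job {run_id} Run State: {job_run_state}"
print(job_error_message)  # Exiting Job jr_0123456789abcdef Run State: FAILED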
29 changes: 13 additions & 16 deletions airflow/providers/amazon/aws/hooks/glue_crawler.py
@@ -102,8 +102,7 @@ def create_crawler(self, **crawler_kwargs) -> str:
         """
         crawler_name = crawler_kwargs['Name']
         self.log.info("Creating crawler: %s", crawler_name)
-        crawler = self.glue_client.create_crawler(**crawler_kwargs)
-        return crawler
+        return self.glue_client.create_crawler(**crawler_kwargs)

     def start_crawler(self, crawler_name: str) -> dict:
         """
@@ -113,8 +112,7 @@ def start_crawler(self, crawler_name: str) -> dict:
         :return: Empty dictionary
         """
         self.log.info("Starting crawler %s", crawler_name)
-        crawler = self.glue_client.start_crawler(Name=crawler_name)
-        return crawler
+        return self.glue_client.start_crawler(Name=crawler_name)

     def wait_for_crawler_completion(self, crawler_name: str, poll_interval: int = 5) -> str:
         """
@@ -137,18 +135,17 @@ def wait_for_crawler_completion(self, crawler_name: str, poll_interval: int = 5) -> str:
                 crawler_status = crawler['LastCrawl']['Status']
                 if crawler_status in failed_status:
                     raise AirflowException(f"Status: {crawler_status}")
-                else:
-                    metrics = self.glue_client.get_crawler_metrics(CrawlerNameList=[crawler_name])[
-                        'CrawlerMetricsList'
-                    ][0]
-                    self.log.info("Status: %s", crawler_status)
-                    self.log.info("Last Runtime Duration (seconds): %s", metrics['LastRuntimeSeconds'])
-                    self.log.info("Median Runtime Duration (seconds): %s", metrics['MedianRuntimeSeconds'])
-                    self.log.info("Tables Created: %s", metrics['TablesCreated'])
-                    self.log.info("Tables Updated: %s", metrics['TablesUpdated'])
-                    self.log.info("Tables Deleted: %s", metrics['TablesDeleted'])
-
-                    return crawler_status
+                metrics = self.glue_client.get_crawler_metrics(CrawlerNameList=[crawler_name])[
+                    'CrawlerMetricsList'
+                ][0]
+                self.log.info("Status: %s", crawler_status)
+                self.log.info("Last Runtime Duration (seconds): %s", metrics['LastRuntimeSeconds'])
+                self.log.info("Median Runtime Duration (seconds): %s", metrics['MedianRuntimeSeconds'])
+                self.log.info("Tables Created: %s", metrics['TablesCreated'])
+                self.log.info("Tables Updated: %s", metrics['TablesUpdated'])
+                self.log.info("Tables Deleted: %s", metrics['TablesDeleted'])
+
+                return crawler_status

             else:
                 self.log.info("Polling for AWS Glue crawler: %s ", crawler_name)
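Since the raise in the failed-status branch already exits the function, the else: wrapper around the metrics block was dead indentation, and removing it flattens the happy path. The same guard-clause idiom in isolation (exception class and status values invented for the sketch):

class CrawlError(Exception):
    pass

def report_status(crawler_status: str) -> str:
    failed_status = {'FAILED', 'CANCELLED'}
    if crawler_status in failed_status:
        raise CrawlError(f"Status: {crawler_status}")
    # No else needed: control only reaches this point when the check
    # passed, so the success path stays at the outer indentation level.
    print(f"Status: {crawler_status}")
    return crawler_status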
4 changes: 1 addition & 3 deletions airflow/providers/amazon/aws/hooks/kinesis.py
@@ -43,9 +43,7 @@ def __init__(self, delivery_stream: str, *args, **kwargs) -> None:

     def put_records(self, records: Iterable):
         """Write batch records to Kinesis Firehose"""
-        response = self.get_conn().put_record_batch(DeliveryStreamName=self.delivery_stream, Records=records)
-
-        return response
+        return self.get_conn().put_record_batch(DeliveryStreamName=self.delivery_stream, Records=records)


 class AwsFirehoseHook(FirehoseHook):
6 changes: 2 additions & 4 deletions airflow/providers/amazon/aws/hooks/lambda_function.py
@@ -66,8 +66,7 @@ def invoke_lambda(
             "Payload": payload,
             "Qualifier": qualifier,
         }
-        response = self.conn.invoke(**{k: v for k, v in invoke_args.items() if v is not None})
-        return response
+        return self.conn.invoke(**{k: v for k, v in invoke_args.items() if v is not None})

     def create_lambda(
         self,
@@ -118,10 +117,9 @@ def create_lambda(
             "CodeSigningConfigArn": code_signing_config_arn,
             "Architectures": architectures,
         }
-        response = self.conn.create_function(
+        return self.conn.create_function(
             **{k: v for k, v in create_function_args.items() if v is not None},
         )
-        return response


 class AwsLambdaHook(LambdaHook):
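Both hunks preserve the existing trick for optional boto3 arguments: collect every possible keyword in a dict, drop the None entries with a comprehension, and splat the remainder into the call. A self-contained sketch of the filtering step (the invoke stub below stands in for boto3's Lambda client and is not the real API):

def invoke(**kwargs):
    # Stand-in that just echoes the keyword arguments it received.
    return kwargs

invoke_args = {
    'FunctionName': 'my-function',
    'InvocationType': 'RequestResponse',
    'LogType': None,   # not supplied by the caller
    'Payload': None,   # not supplied by the caller
}
# Only the non-None keys reach the call, so the API never sees explicit Nones.
result = invoke(**{k: v for k, v in invoke_args.items() if v is not None})
assert result == {'FunctionName': 'my-function', 'InvocationType': 'RequestResponse'}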
4 changes: 1 addition & 3 deletions airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -130,6 +130,4 @@ def get_conn(self) -> RedshiftConnection:
         conn_params = self._get_conn_params()
         conn_kwargs_dejson = self.conn.extra_dejson
         conn_kwargs: Dict = {**conn_params, **conn_kwargs_dejson}
-        conn: RedshiftConnection = redshift_connector.connect(**conn_kwargs)
-
-        return conn
+        return redshift_connector.connect(**conn_kwargs)
13 changes: 5 additions & 8 deletions airflow/providers/amazon/aws/hooks/s3.py
@@ -277,11 +277,10 @@ def list_prefixes(
             Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
         )

-        prefixes = []
+        prefixes = []  # type: List[str]
         for page in response:
             if 'CommonPrefixes' in page:
-                for common_prefix in page['CommonPrefixes']:
-                    prefixes.append(common_prefix['Prefix'])
+                prefixes.extend(common_prefix['Prefix'] for common_prefix in page['CommonPrefixes'])

         return prefixes

@@ -366,12 +365,10 @@ def _is_in_period(input_date: datetime) -> bool:
             StartAfter=start_after_key,
         )

-        keys = []
+        keys = []  # type: List[str]
         for page in response:
             if 'Contents' in page:
-                for k in page['Contents']:
-                    keys.append(k)
-
+                keys.extend(iter(page['Contents']))
         if self.object_filter_usr is not None:
             return self.object_filter_usr(keys, from_datetime, to_datetime)

@@ -604,7 +601,7 @@ def load_file(
             extra_args['ServerSideEncryption'] = "AES256"
         if gzip:
             with open(filename, 'rb') as f_in:
-                filename_gz = f_in.name + '.gz'
+                filename_gz = f'{f_in.name}.gz'
                 with gz.open(filename_gz, 'wb') as f_out:
                     shutil.copyfileobj(f_in, f_out)
                 filename = filename_gz
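In list_prefixes, the inner append loop becomes a single list.extend over a generator expression, one bulk operation per page instead of an append per prefix. Run against a canned paginator response (the page data is fabricated):

pages = [
    {'CommonPrefixes': [{'Prefix': 'raw/'}, {'Prefix': 'staging/'}]},
    {},  # a page without the CommonPrefixes key
    {'CommonPrefixes': [{'Prefix': 'archive/'}]},
]

prefixes = []
for page in pages:
    if 'CommonPrefixes' in page:
        # One extend per page rather than appending prefix by prefix.
        prefixes.extend(common_prefix['Prefix'] for common_prefix in page['CommonPrefixes'])

print(prefixes)  # ['raw/', 'staging/', 'archive/']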
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/log/s3_task_handler.py
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import os
+import pathlib
 import sys

 if sys.version_info >= (3, 8):
@@ -92,8 +93,7 @@ def close(self):
         remote_loc = os.path.join(self.remote_base, self.log_relative_path)
         if os.path.exists(local_loc):
             # read log and remove old logs to get just the latest additions
-            with open(local_loc) as logfile:
-                log = logfile.read()
+            log = pathlib.Path(local_loc).read_text()
             self.s3_write(log, remote_loc)

         # Mark closed so we don't double write if close is called twice
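pathlib.Path.read_text opens, reads, and closes the file in one call, replacing the two-line with block. A quick round trip showing the equivalence (temporary file and contents invented):

import pathlib
import tempfile

with tempfile.NamedTemporaryFile('w', suffix='.log', delete=False) as tmp:
    tmp.write('task log line\n')
    local_loc = tmp.name

# Before:
#     with open(local_loc) as logfile:
#         log = logfile.read()
log = pathlib.Path(local_loc).read_text()  # handle opened and closed internally
assert log == 'task log line\n'
pathlib.Path(local_loc).unlink()  # clean up the temporary file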
3 changes: 1 addition & 2 deletions airflow/providers/amazon/aws/operators/glacier.py
@@ -51,5 +51,4 @@ def __init__(

     def execute(self, context: 'Context'):
         hook = GlacierHook(aws_conn_id=self.aws_conn_id)
-        response = hook.retrieve_inventory(vault_name=self.vault_name)
-        return response
+        return hook.retrieve_inventory(vault_name=self.vault_name)
33 changes: 15 additions & 18 deletions airflow/providers/amazon/aws/secrets/secrets_manager.py
@@ -134,7 +134,7 @@ def __init__(
         self.profile_name = profile_name
         self.sep = sep
         self.full_url_mode = full_url_mode
-        self.extra_conn_words = extra_conn_words if extra_conn_words else {}
+        self.extra_conn_words = extra_conn_words or {}
         self.kwargs = kwargs

     @cached_property
@@ -178,9 +178,7 @@ def get_uri_from_secret(self, secret):

         conn_string = "{conn_type}://{user}:{password}@{host}:{port}/{schema}".format(**conn_d)

-        connection = self._format_uri_with_extra(secret, conn_string)
-
-        return connection
+        return self._format_uri_with_extra(secret, conn_string)

     def get_conn_value(self, conn_id: str):
         """
@@ -193,20 +191,19 @@ def get_conn_value(self, conn_id: str):

         if self.full_url_mode:
             return self._get_secret(self.connections_prefix, conn_id)
-        else:
-            try:
-                secret_string = self._get_secret(self.connections_prefix, conn_id)
-                # json.loads gives error
-                secret = ast.literal_eval(secret_string) if secret_string else None
-            except ValueError:  # 'malformed node or string: ' error, for empty conns
-                connection = None
-                secret = None
-
-            # These lines will check if we have with some denomination stored an username, password and host
-            if secret:
-                connection = self.get_uri_from_secret(secret)
-
-            return connection
+        try:
+            secret_string = self._get_secret(self.connections_prefix, conn_id)
+            # json.loads gives error
+            secret = ast.literal_eval(secret_string) if secret_string else None
+        except ValueError:  # 'malformed node or string: ' error, for empty conns
+            connection = None
+            secret = None
+
+        # These lines will check if we have with some denomination stored an username, password and host
+        if secret:
+            connection = self.get_uri_from_secret(secret)
+
+        return connection

     def get_conn_uri(self, conn_id: str) -> Optional[str]:
         """
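In __init__, `extra_conn_words or {}` replaces the conditional expression; for this argument every falsy input (None or an empty dict) should become an empty dict, so the shorter form is safe. The idiom in isolation, with a note on when it is not (function name invented):

def make_backend(extra_conn_words=None):
    # `x or default` substitutes the default for every falsy x
    # (None, {}, '', 0). That is the intent here, but it would be
    # wrong where falsy-but-valid values such as 0 or '' must survive.
    return extra_conn_words or {}

assert make_backend() == {}
assert make_backend({'role': ['role_arn']}) == {'role': ['role_arn']}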
3 changes: 1 addition & 2 deletions airflow/providers/amazon/aws/secrets/systems_manager.py
@@ -158,8 +158,7 @@ def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]:
         ssm_path = self.build_path(path_prefix, secret_id)
         try:
             response = self.client.get_parameter(Name=ssm_path, WithDecryption=True)
-            value = response["Parameter"]["Value"]
-            return value
+            return response["Parameter"]["Value"]
         except self.client.exceptions.ParameterNotFound:
             self.log.debug("Parameter %s not found.", ssm_path)
             return None
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/sensors/emr.py
@@ -66,7 +66,7 @@ def get_hook(self) -> EmrHook:
     def poke(self, context: 'Context'):
         response = self.get_emr_response()

-        if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
+        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
             self.log.info('Bad HTTP response: %s', response)
             return False

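`!= 200` states the failure condition directly instead of negating an equality test. Pinned down with a fabricated response:

response = {'ResponseMetadata': {'HTTPStatusCode': 503}}

# Before: if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
    print('Bad HTTP response:', response)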
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/sensors/glue.py
@@ -57,7 +57,7 @@ def poke(self, context: 'Context'):
             self.log.info("Exiting Job %s Run State: %s", self.run_id, job_state)
             return True
         elif job_state in self.errored_states:
-            job_error_message = "Exiting Job " + self.run_id + " Run State: " + job_state
+            job_error_message = f"Exiting Job {self.run_id} Run State: {job_state}"
             raise AirflowException(job_error_message)
         else:
             return False
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/sensors/sagemaker.py
@@ -50,7 +50,7 @@ def get_hook(self) -> SageMakerHook:

     def poke(self, context: 'Context'):
         response = self.get_sagemaker_response()
-        if not (response['ResponseMetadata']['HTTPStatusCode'] == 200):
+        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
             self.log.info('Bad HTTP response: %s', response)
             return False
         state = self.state_from_response(response)
@@ -225,7 +225,7 @@ def init_log_resource(self, hook: SageMakerHook) -> None:
         self.instance_count = description['ResourceConfig']['InstanceCount']
         status = description['TrainingJobStatus']
         job_already_completed = status not in self.non_terminal_states()
-        self.state = LogState.TAILING if (not job_already_completed) else LogState.COMPLETE
+        self.state = LogState.COMPLETE if job_already_completed else LogState.TAILING
         self.last_description = description
         self.last_describe_job_call = time.monotonic()
         self.log_resource_inited = True
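The init_log_resource hunk flips the ternary so the flag is tested positively, dropping the `not` and the parentheses without changing the truth table. A runnable sketch (this LogState enum is a stand-in for the real class in the SageMaker hook):

import enum

class LogState(enum.Enum):
    TAILING = 1
    COMPLETE = 2

job_already_completed = True  # fabricated flag for the example

# Before: LogState.TAILING if (not job_already_completed) else LogState.COMPLETE
state = LogState.COMPLETE if job_already_completed else LogState.TAILING
assert state is LogState.COMPLETE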
3 changes: 1 addition & 2 deletions airflow/providers/amazon/aws/transfers/google_api_to_s3.py
@@ -155,13 +155,12 @@ def _retrieve_data_from_google_api(self) -> dict:
             api_version=self.google_api_service_version,
             impersonation_chain=self.google_impersonation_chain,
         )
-        google_api_response = google_discovery_api_hook.query(
+        return google_discovery_api_hook.query(
             endpoint=self.google_api_endpoint_path,
             data=self.google_api_endpoint_params,
             paginate=self.google_api_pagination,
             num_retries=self.google_api_num_retries,
         )
-        return google_api_response

     def _load_data_to_s3(self, data: dict) -> None:
         s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
5 changes: 2 additions & 3 deletions airflow/providers/amazon/aws/transfers/mysql_to_s3.py
@@ -66,8 +66,7 @@ def __init__(
             if "header" not in pd_kwargs:
                 pd_kwargs["header"] = header
             kwargs["pd_kwargs"] = {**kwargs.get('pd_kwargs', {}), **pd_kwargs}
-        else:
-            if pd_csv_kwargs is not None:
-                raise TypeError("pd_csv_kwargs may not be specified when file_format='parquet'")
+        elif pd_csv_kwargs is not None:
+            raise TypeError("pd_csv_kwargs may not be specified when file_format='parquet'")

         super().__init__(sql_conn_id=mysql_conn_id, **kwargs)
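Folding the `else:` with a nested `if` into one `elif` removes an indentation level without changing behavior. The shape of the change, reduced to a stub (function name and print are invented; the TypeError message is from the diff):

def check_format(file_format: str, pd_csv_kwargs=None):
    if file_format == 'csv':
        print('csv options accepted')
    # Before: else: / if pd_csv_kwargs is not None: raise ...
    elif pd_csv_kwargs is not None:
        raise TypeError("pd_csv_kwargs may not be specified when file_format='parquet'")

check_format('csv', {'header': True})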
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/utils/eks_get_token.py
@@ -17,7 +17,7 @@

 import argparse
 import json
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone

 from airflow.providers.amazon.aws.hooks.eks import EksHook

@@ -27,7 +27,8 @@


 def get_expiration_time():
-    token_expiration = datetime.utcnow() + timedelta(minutes=TOKEN_EXPIRATION_MINUTES)
+    token_expiration = datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRATION_MINUTES)
+
     return token_expiration.strftime('%Y-%m-%dT%H:%M:%SZ')
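datetime.utcnow() returns a naive datetime that is easy to mishandle; datetime.now(timezone.utc) yields the same wall-clock value as an aware object pinned to UTC. A runnable check of the replacement (the constant's value below is an assumption for the sketch; the real one lives in eks_get_token.py):

from datetime import datetime, timedelta, timezone

TOKEN_EXPIRATION_MINUTES = 14  # assumed value, for illustration only

# Before: datetime.utcnow() + timedelta(...) -> naive, tzinfo is None
token_expiration = datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRATION_MINUTES)
assert token_expiration.tzinfo is timezone.utc  # aware, pinned to UTC
print(token_expiration.strftime('%Y-%m-%dT%H:%M:%SZ'))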
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/hooks/test_cloud_formation.py
@@ -102,4 +102,4 @@ def test_delete_stack(self):

         stacks = self.hook.get_conn().describe_stacks()['Stacks']
         matching_stacks = [x for x in stacks if x['StackName'] == stack_name]
-        assert len(matching_stacks) == 0, f'stack with name {stack_name} should not exist'
+        assert not matching_stacks, f'stack with name {stack_name} should not exist'
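`assert not matching_stacks` relies on empty sequences being falsy, the idiomatic emptiness check in Python:

matching_stacks = []

# Before: assert len(matching_stacks) == 0, ...
assert not matching_stacks, 'stack should not exist'  # empty list is falsy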
5 changes: 2 additions & 3 deletions tests/providers/amazon/aws/log/test_s3_task_handler.py
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.

+import contextlib
 import os
 import unittest
 from unittest import mock
@@ -75,10 +76,8 @@ def setUp(self):

     def tearDown(self):
         if self.s3_task_handler.handler:
-            try:
+            with contextlib.suppress(Exception):
                 os.remove(self.s3_task_handler.handler.baseFilename)
-            except Exception:
-                pass

     def test_hook(self):
         assert isinstance(self.s3_task_handler.hook, S3Hook)
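contextlib.suppress(Exception) replaces the try/except/pass boilerplate in tearDown: any listed exception raised inside the with block is swallowed and execution continues. Demonstrated against a path that should not exist (the path is made up):

import contextlib
import os

# Before:
#     try:
#         os.remove(path)
#     except Exception:
#         pass
with contextlib.suppress(Exception):
    os.remove('/tmp/definitely-not-a-real-logfile.log')  # hypothetical path
print('still running: the FileNotFoundError was suppressed')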
3 changes: 2 additions & 1 deletion tests/providers/amazon/aws/operators/test_athena.py
@@ -50,7 +50,8 @@ def setUp(self):
             'start_date': DEFAULT_DATE,
         }

-        self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args, schedule_interval='@once')
+        self.dag = DAG(f'{TEST_DAG_ID}test_schedule_dag_once', default_args=args, schedule_interval='@once')
+
         self.athena = AthenaOperator(
             task_id='test_athena_operator',
             query='SELECT * FROM TEST_TABLE',
(file name not shown in this view; the hunk below is from the DMS operator tests)
@@ -57,11 +57,7 @@ def setUp(self):
             "start_date": DEFAULT_DATE,
         }

-        self.dag = DAG(
-            TEST_DAG_ID + "test_schedule_dag_once",
-            default_args=args,
-            schedule_interval="@once",
-        )
+        self.dag = DAG(f"{TEST_DAG_ID}test_schedule_dag_once", default_args=args, schedule_interval="@once")

     def test_init(self):
         dms_operator = DmsDescribeTasksOperator(