Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test_connection method to AWS hook #24662

Merged
merged 1 commit into from
Jul 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import configparser
import datetime
import json
import logging
import warnings
from functools import wraps
Expand Down Expand Up @@ -409,9 +410,6 @@ def __init__(
self.region_name = region_name
self.config = config

if not (self.client_type or self.resource_type):
raise AirflowException('Either client_type or resource_type must be provided.')

def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:

if not self.aws_conn_id:
Expand Down Expand Up @@ -510,13 +508,15 @@ def conn(self) -> Union[boto3.client, boto3.resource]:
:return: boto3.client or boto3.resource
:rtype: Union[boto3.client, boto3.resource]
"""
if self.client_type:
if not ((not self.client_type) ^ (not self.resource_type)):
raise ValueError(
f"Either client_type={self.client_type!r} or "
f"resource_type={self.resource_type!r} must be provided, not both."
)
elif self.client_type:
return self.get_client_type(region_name=self.region_name)
elif self.resource_type:
return self.get_resource_type(region_name=self.region_name)
else:
# Rare possibility - subclasses have not specified a client_type or resource_type
raise NotImplementedError('Could not get boto3 connection!')
return self.get_resource_type(region_name=self.region_name)

@cached_property
def conn_client_meta(self) -> ClientMeta:
Expand Down Expand Up @@ -611,6 +611,29 @@ def decorator_f(self, *args, **kwargs):

return retry_decorator

def test_connection(self):
"""
Tests the AWS connection by call AWS STS (Security Token Service) GetCallerIdentity API.

.. seealso::
https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html
"""
orig_client_type, self.client_type = self.client_type, 'sts'
try:
res = self.get_client_type().get_caller_identity()
metadata = res.pop("ResponseMetadata", {})
if metadata.get("HTTPStatusCode") == 200:
return True, json.dumps(res)
else:
try:
return False, json.dumps(metadata)
except TypeError:
return False, str(metadata)
except Exception as e:
return False, str(e)
finally:
self.client_type = orig_client_type


class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
"""
Expand Down
23 changes: 23 additions & 0 deletions tests/providers/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,29 @@ def test_connection_aws_partition(self, conn_type, connection_uri, expected_part

assert hook.conn_partition == expected_partition

@pytest.mark.parametrize(
"client_type,resource_type",
[
("s3", "dynamodb"),
(None, None),
("", ""),
],
)
def test_connection_client_resource_types_check(self, client_type, resource_type):
# Should not raise any error during Hook initialisation.
hook = AwsBaseHook(aws_conn_id=None, client_type=client_type, resource_type=resource_type)

with pytest.raises(ValueError, match="Either client_type=.* or resource_type=.* must be provided"):
hook.get_conn()

@unittest.skipIf(mock_sts is None, 'mock_sts package not present')
@mock_sts
def test_hook_connection_test(self):
hook = AwsBaseHook(client_type="s3")
result, message = hook.test_connection()
assert result
assert hook.client_type == "s3" # Same client_type which defined during initialisation


class ThrowErrorUntilCount:
"""Holds counter state for invoking a method several times in a row."""
Expand Down