diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index 5e518adde04ac..28135b0fa4ae9 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -26,6 +26,7 @@ import configparser import datetime +import json import logging import warnings from functools import wraps @@ -409,9 +410,6 @@ def __init__( self.region_name = region_name self.config = config - if not (self.client_type or self.resource_type): - raise AirflowException('Either client_type or resource_type must be provided.') - def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]: if not self.aws_conn_id: @@ -510,13 +508,15 @@ def conn(self) -> Union[boto3.client, boto3.resource]: :return: boto3.client or boto3.resource :rtype: Union[boto3.client, boto3.resource] """ - if self.client_type: + if not ((not self.client_type) ^ (not self.resource_type)): + raise ValueError( + f"Either client_type={self.client_type!r} or " + f"resource_type={self.resource_type!r} must be provided, not both." + ) + elif self.client_type: return self.get_client_type(region_name=self.region_name) - elif self.resource_type: - return self.get_resource_type(region_name=self.region_name) else: - # Rare possibility - subclasses have not specified a client_type or resource_type - raise NotImplementedError('Could not get boto3 connection!') + return self.get_resource_type(region_name=self.region_name) @cached_property def conn_client_meta(self) -> ClientMeta: @@ -611,6 +611,29 @@ def decorator_f(self, *args, **kwargs): return retry_decorator + def test_connection(self): + """ + Tests the AWS connection by call AWS STS (Security Token Service) GetCallerIdentity API. + + .. seealso:: + https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html + """ + orig_client_type, self.client_type = self.client_type, 'sts' + try: + res = self.get_client_type().get_caller_identity() + metadata = res.pop("ResponseMetadata", {}) + if metadata.get("HTTPStatusCode") == 200: + return True, json.dumps(res) + else: + try: + return False, json.dumps(metadata) + except TypeError: + return False, str(metadata) + except Exception as e: + return False, str(e) + finally: + self.client_type = orig_client_type + class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]): """ diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py index 00ffc31163057..d1c3da9cdd5ea 100644 --- a/tests/providers/amazon/aws/hooks/test_base_aws.py +++ b/tests/providers/amazon/aws/hooks/test_base_aws.py @@ -701,6 +701,29 @@ def test_connection_aws_partition(self, conn_type, connection_uri, expected_part assert hook.conn_partition == expected_partition + @pytest.mark.parametrize( + "client_type,resource_type", + [ + ("s3", "dynamodb"), + (None, None), + ("", ""), + ], + ) + def test_connection_client_resource_types_check(self, client_type, resource_type): + # Should not raise any error during Hook initialisation. + hook = AwsBaseHook(aws_conn_id=None, client_type=client_type, resource_type=resource_type) + + with pytest.raises(ValueError, match="Either client_type=.* or resource_type=.* must be provided"): + hook.get_conn() + + @unittest.skipIf(mock_sts is None, 'mock_sts package not present') + @mock_sts + def test_hook_connection_test(self): + hook = AwsBaseHook(client_type="s3") + result, message = hook.test_connection() + assert result + assert hook.client_type == "s3" # Same client_type which defined during initialisation + class ThrowErrorUntilCount: """Holds counter state for invoking a method several times in a row."""