diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index 57010ed8097b8..e936557deb945 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -27,6 +27,7 @@ import configparser import datetime import logging +import warnings from functools import wraps from typing import Any, Callable, Dict, Optional, Tuple, Union @@ -433,13 +434,22 @@ def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Se def get_client_type( self, - client_type: str, + client_type: Optional[str] = None, region_name: Optional[str] = None, config: Optional[Config] = None, ) -> boto3.client: """Get the underlying boto3 client using boto3 session""" session, endpoint_url = self._get_credentials(region_name) + if client_type: + warnings.warn( + "client_type is deprecated. Set client_type from class attribute.", + DeprecationWarning, + stacklevel=2, + ) + else: + client_type = self.client_type + # No AWS Operators use the config argument to this method. # Keep backward compatibility with other users who might use it if config is None: @@ -449,13 +459,22 @@ def get_client_type( def get_resource_type( self, - resource_type: str, + resource_type: Optional[str] = None, region_name: Optional[str] = None, config: Optional[Config] = None, ) -> boto3.resource: """Get the underlying boto3 resource using boto3 session""" session, endpoint_url = self._get_credentials(region_name) + if resource_type: + warnings.warn( + "resource_type is deprecated. Set resource_type from class attribute.", + DeprecationWarning, + stacklevel=2, + ) + else: + resource_type = self.resource_type + # No AWS Operators use the config argument to this method. # Keep backward compatibility with other users who might use it if config is None: diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py index c9343878c3f74..0000136d1c63e 100644 --- a/tests/providers/amazon/aws/hooks/test_base_aws.py +++ b/tests/providers/amazon/aws/hooks/test_base_aws.py @@ -111,12 +111,41 @@ def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self): client = boto3.client('emr', region_name='us-east-1') if client.list_clusters()['Clusters']: raise ValueError('AWS not properly mocked') - hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr') client_from_hook = hook.get_client_type('emr') assert client_from_hook.list_clusters()['Clusters'] == [] + @unittest.skipIf(mock_emr is None, 'mock_emr package not present') + @mock_emr + def test_get_client_type_set_in_class_attribute(self): + client = boto3.client('emr', region_name='us-east-1') + if client.list_clusters()['Clusters']: + raise ValueError('AWS not properly mocked') + hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr') + client_from_hook = hook.get_client_type() + + assert client_from_hook.list_clusters()['Clusters'] == [] + + @unittest.skipIf(mock_emr is None, 'mock_emr package not present') + @mock_emr + def test_get_client_type_overwrite(self): + client = boto3.client('emr', region_name='us-east-1') + if client.list_clusters()['Clusters']: + raise ValueError('AWS not properly mocked') + hook = AwsBaseHook(aws_conn_id='aws_default', client_type='dynamodb') + client_from_hook = hook.get_client_type(client_type='emr') + assert client_from_hook.list_clusters()['Clusters'] == [] + + @unittest.skipIf(mock_emr is None, 'mock_emr package not present') + @mock_emr + def test_get_client_type_deprecation_warning(self): + hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr') + warning_message = """client_type is deprecated. Set client_type from class attribute.""" + with pytest.warns(DeprecationWarning) as warnings: + hook.get_client_type(client_type='emr') + assert warning_message == str(warnings[0].message) + @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present') @mock_dynamodb2 def test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(self): @@ -137,6 +166,55 @@ def test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(self): assert table.item_count == 0 + @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present') + @mock_dynamodb2 + def test_get_resource_type_set_in_class_attribute(self): + hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb') + resource_from_hook = hook.get_resource_type() + + # this table needs to be created in production + table = resource_from_hook.create_table( + TableName='test_airflow', + KeySchema=[ + {'AttributeName': 'id', 'KeyType': 'HASH'}, + ], + AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], + ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10}, + ) + + table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow') + + assert table.item_count == 0 + + @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present') + @mock_dynamodb2 + def test_get_resource_type_overwrite(self): + hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='s3') + resource_from_hook = hook.get_resource_type('dynamodb') + + # this table needs to be created in production + table = resource_from_hook.create_table( + TableName='test_airflow', + KeySchema=[ + {'AttributeName': 'id', 'KeyType': 'HASH'}, + ], + AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], + ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10}, + ) + + table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow') + + assert table.item_count == 0 + + @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present') + @mock_dynamodb2 + def test_get_resource_deprecation_warning(self): + hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb') + warning_message = """resource_type is deprecated. Set resource_type from class attribute.""" + with pytest.warns(DeprecationWarning) as warnings: + hook.get_resource_type('dynamodb') + assert warning_message == str(warnings[0].message) + @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present') @mock_dynamodb2 def test_get_session_returns_a_boto3_session(self):