Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AwsBaseHook make client_type & resource_type optional params for get_client_type & get_resource_type #17987

Merged
merged 5 commits into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import configparser
import datetime
import logging
import warnings
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, Union

Expand Down Expand Up @@ -433,13 +434,22 @@ def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Se

def get_client_type(
self,
client_type: str,
client_type: Optional[str] = None,
region_name: Optional[str] = None,
config: Optional[Config] = None,
) -> boto3.client:
"""Get the underlying boto3 client using boto3 session"""
session, endpoint_url = self._get_credentials(region_name)

if client_type:
warnings.warn(
"client_type is deprecated. Set client_type from class attribute.",
DeprecationWarning,
stacklevel=2,
)
else:
client_type = self.client_type

# No AWS Operators use the config argument to this method.
# Keep backward compatibility with other users who might use it
if config is None:
Expand All @@ -449,13 +459,22 @@ def get_client_type(

def get_resource_type(
self,
resource_type: str,
resource_type: Optional[str] = None,
region_name: Optional[str] = None,
config: Optional[Config] = None,
) -> boto3.resource:
"""Get the underlying boto3 resource using boto3 session"""
session, endpoint_url = self._get_credentials(region_name)

if resource_type:
warnings.warn(
"resource_type is deprecated. Set resource_type from class attribute.",
DeprecationWarning,
stacklevel=2,
)
else:
resource_type = self.resource_type

# No AWS Operators use the config argument to this method.
# Keep backward compatibility with other users who might use it
if config is None:
Expand Down
80 changes: 79 additions & 1 deletion tests/providers/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,41 @@ def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
client = boto3.client('emr', region_name='us-east-1')
if client.list_clusters()['Clusters']:
raise ValueError('AWS not properly mocked')

hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
client_from_hook = hook.get_client_type('emr')

assert client_from_hook.list_clusters()['Clusters'] == []

@unittest.skipIf(mock_emr is None, 'mock_emr package not present')
@mock_emr
def test_get_client_type_set_in_class_attribute(self):
client = boto3.client('emr', region_name='us-east-1')
if client.list_clusters()['Clusters']:
raise ValueError('AWS not properly mocked')
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
client_from_hook = hook.get_client_type()

assert client_from_hook.list_clusters()['Clusters'] == []

@unittest.skipIf(mock_emr is None, 'mock_emr package not present')
@mock_emr
def test_get_client_type_overwrite(self):
client = boto3.client('emr', region_name='us-east-1')
if client.list_clusters()['Clusters']:
raise ValueError('AWS not properly mocked')
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='dynamodb')
client_from_hook = hook.get_client_type(client_type='emr')
assert client_from_hook.list_clusters()['Clusters'] == []

@unittest.skipIf(mock_emr is None, 'mock_emr package not present')
@mock_emr
def test_get_client_type_deprecation_warning(self):
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
warning_message = """client_type is deprecated. Set client_type from class attribute."""
with pytest.warns(DeprecationWarning) as warnings:
hook.get_client_type(client_type='emr')
assert warning_message == str(warnings[0].message)

@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
@mock_dynamodb2
def test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(self):
Expand All @@ -137,6 +166,55 @@ def test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(self):

assert table.item_count == 0

@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
@mock_dynamodb2
def test_get_resource_type_set_in_class_attribute(self):
hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb')
resource_from_hook = hook.get_resource_type()

# this table needs to be created in production
table = resource_from_hook.create_table(
TableName='test_airflow',
KeySchema=[
{'AttributeName': 'id', 'KeyType': 'HASH'},
],
AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}],
ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10},
)

table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')

assert table.item_count == 0

@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
@mock_dynamodb2
def test_get_resource_type_overwrite(self):
hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='s3')
resource_from_hook = hook.get_resource_type('dynamodb')

# this table needs to be created in production
table = resource_from_hook.create_table(
TableName='test_airflow',
KeySchema=[
{'AttributeName': 'id', 'KeyType': 'HASH'},
],
AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}],
ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10},
)

table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')

assert table.item_count == 0

@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
@mock_dynamodb2
def test_get_resource_deprecation_warning(self):
hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb')
warning_message = """resource_type is deprecated. Set resource_type from class attribute."""
with pytest.warns(DeprecationWarning) as warnings:
hook.get_resource_type('dynamodb')
assert warning_message == str(warnings[0].message)

@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
@mock_dynamodb2
def test_get_session_returns_a_boto3_session(self):
Expand Down