Skip to content

Commit

Permalink
AwsBaseHook make client_type & resource_type optional params for …
Browse files Browse the repository at this point in the history
…`get_client_type` & `get_resource_type` (#17987)

* AwsBaseHook make client_type & resource_type optional params for get_client_type & get_resource_type
  • Loading branch information
eladkal authored Sep 3, 2021
1 parent decaaeb commit 867e930
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 3 deletions.
23 changes: 21 additions & 2 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import configparser
import datetime
import logging
import warnings
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, Union

Expand Down Expand Up @@ -433,13 +434,22 @@ def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Se

def get_client_type(
self,
client_type: str,
client_type: Optional[str] = None,
region_name: Optional[str] = None,
config: Optional[Config] = None,
) -> boto3.client:
"""Get the underlying boto3 client using boto3 session"""
session, endpoint_url = self._get_credentials(region_name)

if client_type:
warnings.warn(
"client_type is deprecated. Set client_type from class attribute.",
DeprecationWarning,
stacklevel=2,
)
else:
client_type = self.client_type

# No AWS Operators use the config argument to this method.
# Keep backward compatibility with other users who might use it
if config is None:
Expand All @@ -449,13 +459,22 @@ def get_client_type(

def get_resource_type(
self,
resource_type: str,
resource_type: Optional[str] = None,
region_name: Optional[str] = None,
config: Optional[Config] = None,
) -> boto3.resource:
"""Get the underlying boto3 resource using boto3 session"""
session, endpoint_url = self._get_credentials(region_name)

if resource_type:
warnings.warn(
"resource_type is deprecated. Set resource_type from class attribute.",
DeprecationWarning,
stacklevel=2,
)
else:
resource_type = self.resource_type

# No AWS Operators use the config argument to this method.
# Keep backward compatibility with other users who might use it
if config is None:
Expand Down
80 changes: 79 additions & 1 deletion tests/providers/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,41 @@ def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
client = boto3.client('emr', region_name='us-east-1')
if client.list_clusters()['Clusters']:
raise ValueError('AWS not properly mocked')

hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
client_from_hook = hook.get_client_type('emr')

assert client_from_hook.list_clusters()['Clusters'] == []

@unittest.skipIf(mock_emr is None, 'mock_emr package not present')
@mock_emr
def test_get_client_type_set_in_class_attribute(self):
client = boto3.client('emr', region_name='us-east-1')
if client.list_clusters()['Clusters']:
raise ValueError('AWS not properly mocked')
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
client_from_hook = hook.get_client_type()

assert client_from_hook.list_clusters()['Clusters'] == []

@unittest.skipIf(mock_emr is None, 'mock_emr package not present')
@mock_emr
def test_get_client_type_overwrite(self):
client = boto3.client('emr', region_name='us-east-1')
if client.list_clusters()['Clusters']:
raise ValueError('AWS not properly mocked')
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='dynamodb')
client_from_hook = hook.get_client_type(client_type='emr')
assert client_from_hook.list_clusters()['Clusters'] == []

@unittest.skipIf(mock_emr is None, 'mock_emr package not present')
@mock_emr
def test_get_client_type_deprecation_warning(self):
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
warning_message = """client_type is deprecated. Set client_type from class attribute."""
with pytest.warns(DeprecationWarning) as warnings:
hook.get_client_type(client_type='emr')
assert warning_message == str(warnings[0].message)

@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
@mock_dynamodb2
def test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(self):
Expand All @@ -137,6 +166,55 @@ def test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(self):

assert table.item_count == 0

@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
@mock_dynamodb2
def test_get_resource_type_set_in_class_attribute(self):
hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb')
resource_from_hook = hook.get_resource_type()

# this table needs to be created in production
table = resource_from_hook.create_table(
TableName='test_airflow',
KeySchema=[
{'AttributeName': 'id', 'KeyType': 'HASH'},
],
AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}],
ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10},
)

table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')

assert table.item_count == 0

@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
@mock_dynamodb2
def test_get_resource_type_overwrite(self):
hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='s3')
resource_from_hook = hook.get_resource_type('dynamodb')

# this table needs to be created in production
table = resource_from_hook.create_table(
TableName='test_airflow',
KeySchema=[
{'AttributeName': 'id', 'KeyType': 'HASH'},
],
AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}],
ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10},
)

table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')

assert table.item_count == 0

@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
@mock_dynamodb2
def test_get_resource_deprecation_warning(self):
hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb')
warning_message = """resource_type is deprecated. Set resource_type from class attribute."""
with pytest.warns(DeprecationWarning) as warnings:
hook.get_resource_type('dynamodb')
assert warning_message == str(warnings[0].message)

@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
@mock_dynamodb2
def test_get_session_returns_a_boto3_session(self):
Expand Down

0 comments on commit 867e930

Please sign in to comment.