From bbcc72f81a665edc4b5b3768beeb62999f382e93 Mon Sep 17 00:00:00 2001 From: KieranBrantnerMagee Date: Wed, 9 Sep 2020 21:45:25 -0700 Subject: [PATCH] [ServiceBus] Support SAS token-via-connection-string auth, and remove ServiceBusSharedKeyCredential export (#13627) - Remove public documentation and exports of ServiceBusSharedKeyCredential until we chose to release it across all languages. - Support for Sas Token connection strings (tests, etc) - Add safety net for if signature and key are both provided in connstr (inspired by .nets approach) Co-authored-by: Rakshith Bhyravabhotla --- sdk/servicebus/azure-servicebus/CHANGELOG.md | 2 + .../azure/servicebus/__init__.py | 2 - .../azure/servicebus/_base_handler.py | 106 +++++++++++++----- .../azure/servicebus/_common/utils.py | 40 +++++-- .../azure/servicebus/_servicebus_client.py | 18 ++- .../azure/servicebus/_servicebus_receiver.py | 9 +- .../azure/servicebus/_servicebus_sender.py | 9 +- .../_servicebus_session_receiver.py | 4 +- .../azure/servicebus/aio/__init__.py | 2 - .../servicebus/aio/_base_handler_async.py | 39 ++++++- .../aio/_servicebus_client_async.py | 14 ++- .../aio/_servicebus_receiver_async.py | 10 +- .../aio/_servicebus_sender_async.py | 10 +- .../aio/_servicebus_session_receiver_async.py | 4 +- .../management/_management_client_async.py | 14 ++- .../management/_management_client.py | 14 ++- .../mgmt_tests/test_mgmt_queues_async.py | 2 +- .../tests/async_tests/test_sb_client_async.py | 30 +++++ .../async_tests/test_subscriptions_async.py | 3 +- .../tests/async_tests/test_topic_async.py | 3 +- .../tests/mgmt_tests/test_mgmt_queues.py | 2 +- .../azure-servicebus/tests/test_sb_client.py | 40 ++++++- .../tests/test_subscriptions.py | 3 +- .../azure-servicebus/tests/test_topic.py | 3 +- 24 files changed, 290 insertions(+), 93 deletions(-) diff --git a/sdk/servicebus/azure-servicebus/CHANGELOG.md b/sdk/servicebus/azure-servicebus/CHANGELOG.md index cd1d87c13cc8..ff0b31391d7e 100644 --- a/sdk/servicebus/azure-servicebus/CHANGELOG.md +++ b/sdk/servicebus/azure-servicebus/CHANGELOG.md @@ -7,6 +7,7 @@ * `renew_lock()` now returns the UTC datetime that the lock is set to expire at. * `receive_deferred_messages()` can now take a single sequence number as well as a list of sequence numbers. * Messages can now be sent twice in succession. +* Connection strings used with `from_connection_string` methods now support using the `SharedAccessSignature` key in leiu of `sharedaccesskey` and `sharedaccesskeyname`, taking the string of the properly constructed token as value. * Internal AMQP message properties (header, footer, annotations, properties, etc) are now exposed via `Message.amqp_message` **Breaking Changes** @@ -31,6 +32,7 @@ * Remove `support_ordering` from `create_queue` and `QueueProperties` * Remove `enable_subscription_partitioning` from `create_topic` and `TopicProperties` * `get_dead_letter_[queue,subscription]_receiver()` has been removed. To connect to a dead letter queue, utilize the `sub_queue` parameter of `get_[queue,subscription]_receiver()` provided with a value from the `SubQueue` enum +* No longer export `ServiceBusSharedKeyCredential` * Rename `entity_availability_status` to `availability_status` ## 7.0.0b5 (2020-08-10) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/__init__.py b/sdk/servicebus/azure-servicebus/azure/servicebus/__init__.py index 5ee9a06b6b01..80c9eaaaabca 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/__init__.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/__init__.py @@ -13,7 +13,6 @@ from ._servicebus_receiver import ServiceBusReceiver from ._servicebus_session_receiver import ServiceBusSessionReceiver from ._servicebus_session import ServiceBusSession -from ._base_handler import ServiceBusSharedKeyCredential from ._common.message import Message, BatchMessage, PeekedMessage, ReceivedMessage from ._common.constants import ReceiveMode, SubQueue from ._common.auto_lock_renewer import AutoLockRenew @@ -32,7 +31,6 @@ 'ServiceBusSessionReceiver', 'ServiceBusSession', 'ServiceBusSender', - 'ServiceBusSharedKeyCredential', 'TransportType', 'AutoLockRenew' ] diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py index 80b181d116a9..545fa074807f 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py @@ -7,7 +7,7 @@ import uuid import time from datetime import timedelta -from typing import cast, Optional, Tuple, TYPE_CHECKING, Dict, Any, Callable, Type +from typing import cast, Optional, Tuple, TYPE_CHECKING, Dict, Any, Callable try: from urllib import quote_plus # type: ignore @@ -18,6 +18,8 @@ from uamqp import utils from uamqp.message import MessageProperties +from azure.core.credentials import AccessToken + from ._common._configuration import Configuration from .exceptions import ( ServiceBusError, @@ -41,11 +43,13 @@ def _parse_conn_str(conn_str): - # type: (str) -> Tuple[str, str, str, str] + # type: (str) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]] endpoint = None shared_access_key_name = None shared_access_key = None entity_path = None # type: Optional[str] + shared_access_signature = None # type: Optional[str] + shared_access_signature_expiry = None # type: Optional[int] for element in conn_str.split(";"): key, _, value = element.partition("=") if key.lower() == "endpoint": @@ -58,10 +62,21 @@ def _parse_conn_str(conn_str): shared_access_key = value elif key.lower() == "entitypath": entity_path = value - if not all([endpoint, shared_access_key_name, shared_access_key]): + elif key.lower() == "sharedaccesssignature": + shared_access_signature = value + try: + # Expiry can be stored in the "se=" clause of the token. ('&'-separated key-value pairs) + # type: ignore + shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0]) + except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional. + # An arbitrary, absurdly large number, since you can't renew. + shared_access_signature_expiry = int(time.time() * 2) + if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))) \ + or all((shared_access_key_name, shared_access_signature)): # this latter clause since we don't accept both raise ValueError( "Invalid connection string. Should be in the format: " "Endpoint=sb:///;SharedAccessKeyName=;SharedAccessKey=" + "\nWith alternate option of providing SharedAccessSignature instead of SharedAccessKeyName and Key" ) entity = cast(str, entity_path) left_slash_pos = cast(str, endpoint).find("//") @@ -69,7 +84,13 @@ def _parse_conn_str(conn_str): host = cast(str, endpoint)[left_slash_pos + 2:] else: host = str(endpoint) - return host, str(shared_access_key_name), str(shared_access_key), entity + + return (host, + str(shared_access_key_name) if shared_access_key_name else None, + str(shared_access_key) if shared_access_key else None, + entity, + str(shared_access_signature) if shared_access_signature else None, + shared_access_signature_expiry) def _generate_sas_token(uri, policy, key, expiry=None): @@ -90,29 +111,27 @@ def _generate_sas_token(uri, policy, key, expiry=None): return _AccessToken(token=token, expires_on=abs_expiry) -def _convert_connection_string_to_kwargs(conn_str, shared_key_credential_type, **kwargs): - # type: (str, Type, Any) -> Dict[str, Any] - host, policy, key, entity_in_conn_str = _parse_conn_str(conn_str) - queue_name = kwargs.get("queue_name") - topic_name = kwargs.get("topic_name") - if not (queue_name or topic_name or entity_in_conn_str): - raise ValueError("Entity name is missing. Please specify `queue_name` or `topic_name`" - " or use a connection string including the entity information.") - - if queue_name and topic_name: - raise ValueError("`queue_name` and `topic_name` can not be specified simultaneously.") - - entity_in_kwargs = queue_name or topic_name - if entity_in_conn_str and entity_in_kwargs and (entity_in_conn_str != entity_in_kwargs): - raise ServiceBusAuthenticationError( - "Entity names do not match, the entity name in connection string is {};" - " the entity name in parameter is {}.".format(entity_in_conn_str, entity_in_kwargs) - ) +class ServiceBusSASTokenCredential(object): + """The shared access token credential used for authentication. + :param str token: The shared access token string + :param int expiry: The epoch timestamp + """ + def __init__(self, token, expiry): + # type: (str, int) -> None + """ + :param str token: The shared access token string + :param float expiry: The epoch timestamp + """ + self.token = token + self.expiry = expiry + self.token_type = b"servicebus.windows.net:sastoken" - kwargs["fully_qualified_namespace"] = host - kwargs["entity_name"] = entity_in_conn_str or entity_in_kwargs - kwargs["credential"] = shared_key_credential_type(policy, key) - return kwargs + def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument + # type: (str, Any) -> AccessToken + """ + This method is automatically called when token is about to expire. + """ + return AccessToken(self.token, self.expiry) class ServiceBusSharedKeyCredential(object): @@ -158,6 +177,41 @@ def __init__( self._auth_uri = None self._properties = create_properties(self._config.user_agent) + @classmethod + def _convert_connection_string_to_kwargs(cls, conn_str, **kwargs): + # type: (str, Any) -> Dict[str, Any] + host, policy, key, entity_in_conn_str, token, token_expiry = _parse_conn_str(conn_str) + queue_name = kwargs.get("queue_name") + topic_name = kwargs.get("topic_name") + if not (queue_name or topic_name or entity_in_conn_str): + raise ValueError("Entity name is missing. Please specify `queue_name` or `topic_name`" + " or use a connection string including the entity information.") + + if queue_name and topic_name: + raise ValueError("`queue_name` and `topic_name` can not be specified simultaneously.") + + entity_in_kwargs = queue_name or topic_name + if entity_in_conn_str and entity_in_kwargs and (entity_in_conn_str != entity_in_kwargs): + raise ServiceBusAuthenticationError( + "Entity names do not match, the entity name in connection string is {};" + " the entity name in parameter is {}.".format(entity_in_conn_str, entity_in_kwargs) + ) + + kwargs["fully_qualified_namespace"] = host + kwargs["entity_name"] = entity_in_conn_str or entity_in_kwargs + # This has to be defined seperately to support sync vs async credentials. + kwargs["credential"] = cls._create_credential_from_connection_string_parameters(token, + token_expiry, + policy, + key) + return kwargs + + @classmethod + def _create_credential_from_connection_string_parameters(cls, token, token_expiry, policy, key): + if token and token_expiry: + return ServiceBusSASTokenCredential(token, token_expiry) + return ServiceBusSharedKeyCredential(policy, key) + def __enter__(self): self._open_with_retry() return self diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py index 65241798087b..bc6c84cb1e32 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py @@ -9,7 +9,8 @@ import logging import functools import platform -from typing import Optional, Dict +import time +from typing import Optional, Dict, Tuple try: from urlparse import urlparse except ImportError: @@ -63,11 +64,15 @@ def utc_now(): return datetime.datetime.now(tz=TZ_UTC) +# This parse_conn_str is used for mgmt, the other in base_handler for handlers. Should be unified. def parse_conn_str(conn_str): - endpoint = None - shared_access_key_name = None - shared_access_key = None - entity_path = None + # type: (str) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]] + endpoint = '' + shared_access_key_name = None # type: Optional[str] + shared_access_key = None # type: Optional[str] + entity_path = '' + shared_access_signature = None # type: Optional[str] + shared_access_signature_expiry = None # type: Optional[int] for element in conn_str.split(';'): key, _, value = element.partition('=') if key.lower() == 'endpoint': @@ -78,9 +83,28 @@ def parse_conn_str(conn_str): shared_access_key = value elif key.lower() == 'entitypath': entity_path = value - if not all([endpoint, shared_access_key_name, shared_access_key]): - raise ValueError("Invalid connection string") - return endpoint, shared_access_key_name, shared_access_key, entity_path + elif key.lower() == "sharedaccesssignature": + shared_access_signature = value + try: + # Expiry can be stored in the "se=" clause of the token. ('&'-separated key-value pairs) + # type: ignore + shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0]) + except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional. + # An arbitrary, absurdly large number, since you can't renew. + shared_access_signature_expiry = int(time.time() * 2) + if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))) \ + or all((shared_access_key_name, shared_access_signature)): # this latter clause since we don't accept both + raise ValueError( + "Invalid connection string. Should be in the format: " + "Endpoint=sb:///;SharedAccessKeyName=;SharedAccessKey=" + "\nWith alternate option of providing SharedAccessSignature instead of SharedAccessKeyName and Key" + ) + return (endpoint, + str(shared_access_key_name) if shared_access_key_name else None, + str(shared_access_key) if shared_access_key else None, + entity_path, + str(shared_access_signature) if shared_access_signature else None, + shared_access_signature_expiry) def build_uri(address, entity): diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py index 7e9d727ada70..8fcbc9efcaf8 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py @@ -7,7 +7,11 @@ import uamqp -from ._base_handler import _parse_conn_str, ServiceBusSharedKeyCredential, BaseHandler +from ._base_handler import ( + _parse_conn_str, + ServiceBusSharedKeyCredential, + ServiceBusSASTokenCredential, + BaseHandler) from ._servicebus_sender import ServiceBusSender from ._servicebus_receiver import ServiceBusReceiver from ._servicebus_session_receiver import ServiceBusSessionReceiver @@ -33,8 +37,8 @@ class ServiceBusClient(object): The namespace format is: `.servicebus.windows.net`. :param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which implements a particular interface for getting tokens. It accepts - :class:`ServiceBusSharedKeyCredential`, or credential objects - generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method. + credential objects generated by the azure-identity library and objects that implement the + `get_token(self, *scopes)` method. :keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`. :keyword transport_type: The type of transport protocol that will be used for communicating with the Service Bus service. Default is `TransportType.Amqp`. @@ -153,11 +157,15 @@ def from_connection_string( :caption: Create a new instance of the ServiceBusClient from connection string. """ - host, policy, key, entity_in_conn_str = _parse_conn_str(conn_str) + host, policy, key, entity_in_conn_str, token, token_expiry = _parse_conn_str(conn_str) + if token and token_expiry: + credential = ServiceBusSASTokenCredential(token, token_expiry) + elif policy and key: + credential = ServiceBusSharedKeyCredential(policy, key) # type: ignore return cls( fully_qualified_namespace=host, entity_name=entity_in_conn_str or kwargs.pop("entity_name", None), - credential=ServiceBusSharedKeyCredential(policy, key), # type: ignore + credential=credential, # type: ignore **kwargs ) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_receiver.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_receiver.py index aefd25af3fe5..9341fbc87baf 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_receiver.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_receiver.py @@ -13,7 +13,7 @@ from uamqp.constants import SenderSettleMode from uamqp.authentication.common import AMQPAuth -from ._base_handler import BaseHandler, ServiceBusSharedKeyCredential, _convert_connection_string_to_kwargs +from ._base_handler import BaseHandler from ._common.utils import create_authentication from ._common.message import PeekedMessage, ReceivedMessage from ._common.constants import ( @@ -56,8 +56,8 @@ class ServiceBusReceiver(BaseHandler, ReceiverMixin): # pylint: disable=too-man The namespace format is: `.servicebus.windows.net`. :param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which implements a particular interface for getting tokens. It accepts - :class:`ServiceBusSharedKeyCredential`, or credential objects - generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method. + :class: credential objects generated by the azure-identity library and objects that implement the + `get_token(self, *scopes)` method. :keyword str queue_name: The path of specific Service Bus Queue the client connects to. :keyword str topic_name: The path of specific Service Bus Topic which contains the Subscription the client connects to. @@ -363,9 +363,8 @@ def from_connection_string( :caption: Create a new instance of the ServiceBusReceiver from connection string. """ - constructor_args = _convert_connection_string_to_kwargs( + constructor_args = cls._convert_connection_string_to_kwargs( conn_str, - ServiceBusSharedKeyCredential, **kwargs ) if kwargs.get("queue_name") and kwargs.get("subscription_name"): diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_sender.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_sender.py index f2e099d712c1..d664f1165aef 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_sender.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_sender.py @@ -11,7 +11,7 @@ from uamqp import SendClient, types from uamqp.authentication.common import AMQPAuth -from ._base_handler import BaseHandler, ServiceBusSharedKeyCredential, _convert_connection_string_to_kwargs +from ._base_handler import BaseHandler from ._common import mgmt_handlers from ._common.message import Message, BatchMessage from .exceptions import ( @@ -97,8 +97,8 @@ class ServiceBusSender(BaseHandler, SenderMixin): The namespace format is: `.servicebus.windows.net`. :param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which implements a particular interface for getting tokens. It accepts - :class:`ServiceBusSharedKeyCredential`, or credential objects - generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method. + :class: credential objects generated by the azure-identity library and objects that implement the + `get_token(self, *scopes)` method. :keyword str queue_name: The path of specific Service Bus Queue the client connects to. :keyword str topic_name: The path of specific Service Bus Topic the client connects to. :keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`. @@ -293,9 +293,8 @@ def from_connection_string( :caption: Create a new instance of the ServiceBusSender from connection string. """ - constructor_args = _convert_connection_string_to_kwargs( + constructor_args = cls._convert_connection_string_to_kwargs( conn_str, - ServiceBusSharedKeyCredential, **kwargs ) return cls(**constructor_args) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_session_receiver.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_session_receiver.py index 0393b487c098..51381600da61 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_session_receiver.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_session_receiver.py @@ -33,8 +33,8 @@ class ServiceBusSessionReceiver(ServiceBusReceiver, SessionReceiverMixin): The namespace format is: `.servicebus.windows.net`. :param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which implements a particular interface for getting tokens. It accepts - :class:`ServiceBusSharedKeyCredential`, or credential objects - generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method. + :class: credential objects generated by the azure-identity library and objects that implement the + `get_token(self, *scopes)` method. :keyword str queue_name: The path of specific Service Bus Queue the client connects to. :keyword str topic_name: The path of specific Service Bus Topic which contains the Subscription the client connects to. diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/__init__.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/__init__.py index 4e13bf070d49..0cf6aa9cdc25 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/__init__.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/__init__.py @@ -4,7 +4,6 @@ # license information. # ------------------------------------------------------------------------- from ._async_message import ReceivedMessage -from ._base_handler_async import ServiceBusSharedKeyCredential from ._servicebus_sender_async import ServiceBusSender from ._servicebus_receiver_async import ServiceBusReceiver from ._servicebus_session_receiver_async import ServiceBusSessionReceiver @@ -19,6 +18,5 @@ 'ServiceBusReceiver', 'ServiceBusSessionReceiver', 'ServiceBusSession', - 'ServiceBusSharedKeyCredential', 'AutoLockRenew' ] diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_base_handler_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_base_handler_async.py index 125574319d79..d6fb5d3722aa 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_base_handler_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_base_handler_async.py @@ -9,7 +9,10 @@ import uamqp from uamqp.message import MessageProperties -from .._base_handler import _generate_sas_token, _AccessToken + +from azure.core.credentials import AccessToken + +from .._base_handler import _generate_sas_token, _AccessToken, BaseHandler as BaseHandlerSync from .._common._configuration import Configuration from .._common.utils import create_properties from .._common.constants import ( @@ -23,11 +26,32 @@ ) if TYPE_CHECKING: - from azure.core.credentials import TokenCredential, AccessToken + from azure.core.credentials import TokenCredential _LOGGER = logging.getLogger(__name__) +class ServiceBusSASTokenCredential(object): + """The shared access token credential used for authentication. + :param str token: The shared access token string + :param int expiry: The epoch timestamp + """ + def __init__(self, token: str, expiry: int) -> None: + """ + :param str token: The shared access token string + :param int expiry: The epoch timestamp + """ + self.token = token + self.expiry = expiry + self.token_type = b"servicebus.windows.net:sastoken" + + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument + """ + This method is automatically called when token is about to expire. + """ + return AccessToken(self.token, self.expiry) + + class ServiceBusSharedKeyCredential(object): """The shared access key credential used for authentication. @@ -69,6 +93,17 @@ def __init__( self._auth_uri = None self._properties = create_properties(self._config.user_agent) + @classmethod + def _convert_connection_string_to_kwargs(cls, conn_str, **kwargs): + # pylint:disable=protected-access + return BaseHandlerSync._convert_connection_string_to_kwargs(conn_str, **kwargs) + + @classmethod + def _create_credential_from_connection_string_parameters(cls, token, token_expiry, policy, key): + if token and token_expiry: + return ServiceBusSASTokenCredential(token, token_expiry) + return ServiceBusSharedKeyCredential(policy, key) + async def __aenter__(self): await self._open_with_retry() return self diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py index dc291e08f91a..83134591c845 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py @@ -8,7 +8,7 @@ import uamqp from .._base_handler import _parse_conn_str -from ._base_handler_async import ServiceBusSharedKeyCredential, BaseHandler +from ._base_handler_async import ServiceBusSharedKeyCredential, ServiceBusSASTokenCredential, BaseHandler from ._servicebus_sender_async import ServiceBusSender from ._servicebus_receiver_async import ServiceBusReceiver from ._servicebus_session_receiver_async import ServiceBusSessionReceiver @@ -35,8 +35,8 @@ class ServiceBusClient(object): The namespace format is: `.servicebus.windows.net`. :param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which implements a particular interface for getting tokens. It accepts - :class:`ServiceBusSharedKeyCredential`, or credential objects - generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method. + credential objects generated by the azure-identity library and objects that implement the + `get_token(self, *scopes)` method. :keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`. :keyword transport_type: The type of transport protocol that will be used for communicating with the Service Bus service. Default is `TransportType.Amqp`. @@ -133,11 +133,15 @@ def from_connection_string( :caption: Create a new instance of the ServiceBusClient from connection string. """ - host, policy, key, entity_in_conn_str = _parse_conn_str(conn_str) + host, policy, key, entity_in_conn_str, token, token_expiry = _parse_conn_str(conn_str) + if token and token_expiry: + credential = ServiceBusSASTokenCredential(token, token_expiry) + elif policy and key: + credential = ServiceBusSharedKeyCredential(policy, key) # type: ignore return cls( fully_qualified_namespace=host, entity_name=entity_in_conn_str or kwargs.pop("entity_name", None), - credential=ServiceBusSharedKeyCredential(policy, key), # type: ignore + credential=credential, # type: ignore **kwargs ) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_receiver_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_receiver_async.py index 9ca36042b814..c6081c2557f7 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_receiver_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_receiver_async.py @@ -13,9 +13,8 @@ from uamqp import ReceiveClientAsync, types, Message from uamqp.constants import SenderSettleMode -from ._base_handler_async import BaseHandler, ServiceBusSharedKeyCredential +from ._base_handler_async import BaseHandler from ._async_message import ReceivedMessage -from .._base_handler import _convert_connection_string_to_kwargs from .._common.receiver_mixins import ReceiverMixin from .._common.constants import ( REQUEST_RESPONSE_UPDATE_DISPOSTION_OPERATION, @@ -56,8 +55,8 @@ class ServiceBusReceiver(collections.abc.AsyncIterator, BaseHandler, ReceiverMix The namespace format is: `.servicebus.windows.net`. :param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which implements a particular interface for getting tokens. It accepts - :class:`ServiceBusSharedKeyCredential`, or credential objects - generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method. + :class: credential objects generated by the azure-identity library and objects that implement the + `get_token(self, *scopes)` method. :keyword str queue_name: The path of specific Service Bus Queue the client connects to. :keyword str topic_name: The path of specific Service Bus Topic which contains the Subscription the client connects to. @@ -358,9 +357,8 @@ def from_connection_string( :caption: Create a new instance of the ServiceBusReceiver from connection string. """ - constructor_args = _convert_connection_string_to_kwargs( + constructor_args = cls._convert_connection_string_to_kwargs( conn_str, - ServiceBusSharedKeyCredential, **kwargs ) if kwargs.get("queue_name") and kwargs.get("subscription_name"): diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_sender_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_sender_async.py index 93424d9b263d..082a9371bf65 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_sender_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_sender_async.py @@ -10,9 +10,8 @@ from uamqp import SendClientAsync, types from .._common.message import Message, BatchMessage -from .._base_handler import _convert_connection_string_to_kwargs from .._servicebus_sender import SenderMixin -from ._base_handler_async import BaseHandler, ServiceBusSharedKeyCredential +from ._base_handler_async import BaseHandler from .._common.constants import ( REQUEST_RESPONSE_SCHEDULE_MESSAGE_OPERATION, REQUEST_RESPONSE_CANCEL_SCHEDULED_MESSAGE_OPERATION, @@ -43,8 +42,8 @@ class ServiceBusSender(BaseHandler, SenderMixin): The namespace format is: `.servicebus.windows.net`. :param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which implements a particular interface for getting tokens. It accepts - :class:`ServiceBusSharedKeyCredential`, or credential objects - generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method. + :class: credential objects generated by the azure-identity library and objects that implement the + `get_token(self, *scopes)` method. :keyword str queue_name: The path of specific Service Bus Queue the client connects to. Only one of queue_name or topic_name can be provided. :keyword str topic_name: The path of specific Service Bus Topic the client connects to. @@ -232,9 +231,8 @@ def from_connection_string( :caption: Create a new instance of the ServiceBusSender from connection string. """ - constructor_args = _convert_connection_string_to_kwargs( + constructor_args = cls._convert_connection_string_to_kwargs( conn_str, - ServiceBusSharedKeyCredential, **kwargs ) return cls(**constructor_args) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_session_receiver_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_session_receiver_async.py index c6ccaa0dfb1a..8685fcc8664e 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_session_receiver_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_session_receiver_async.py @@ -33,8 +33,8 @@ class ServiceBusSessionReceiver(ServiceBusReceiver, SessionReceiverMixin): The namespace format is: `.servicebus.windows.net`. :param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which implements a particular interface for getting tokens. It accepts - :class:`ServiceBusSharedKeyCredential`, or credential objects - generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method. + :class: credential objects generated by the azure-identity library and objects that implement the + `get_token(self, *scopes)` method. :keyword str queue_name: The path of specific Service Bus Queue the client connects to. :keyword str topic_name: The path of specific Service Bus Topic which contains the Subscription the client connects to. diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_management_client_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_management_client_async.py index e495126304a1..ecce6a608c28 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_management_client_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_management_client_async.py @@ -24,7 +24,7 @@ from ..._common.utils import parse_conn_str from ..._common.constants import JWT_TOKEN_SCOPE -from ...aio._base_handler_async import ServiceBusSharedKeyCredential +from ...aio._base_handler_async import ServiceBusSharedKeyCredential, ServiceBusSASTokenCredential from ...management._generated.aio._configuration_async import ServiceBusManagementClientConfiguration from ...management._generated.aio._service_bus_management_client_async import ServiceBusManagementClient \ as ServiceBusManagementClientImpl @@ -48,12 +48,12 @@ class ServiceBusAdministrationClient: #pylint:disable=too-many-public-methods :param str fully_qualified_namespace: The fully qualified host name for the Service Bus namespace. :param credential: To authenticate to manage the entities of the ServiceBus namespace. - :type credential: Union[AsyncTokenCredential, ~azure.servicebus.aio.ServiceBusSharedKeyCredential] + :type credential: AsyncTokenCredential """ def __init__( self, fully_qualified_namespace: str, - credential: Union["AsyncTokenCredential", ServiceBusSharedKeyCredential], + credential: Union["AsyncTokenCredential"], **kwargs) -> None: self.fully_qualified_namespace = fully_qualified_namespace @@ -135,10 +135,14 @@ def from_connection_string(cls, conn_str: str, **kwargs: Any) -> "ServiceBusAdmi :param str conn_str: The connection string of the Service Bus Namespace. :rtype: ~azure.servicebus.management.aio.ServiceBusAdministrationClient """ - endpoint, shared_access_key_name, shared_access_key, _ = parse_conn_str(conn_str) + endpoint, shared_access_key_name, shared_access_key, _, token, token_expiry = parse_conn_str(conn_str) + if token and token_expiry: + credential = ServiceBusSASTokenCredential(token, token_expiry) + elif shared_access_key_name and shared_access_key: + credential = ServiceBusSharedKeyCredential(shared_access_key_name, shared_access_key) # type: ignore if "//" in endpoint: endpoint = endpoint[endpoint.index("//")+2:] - return cls(endpoint, ServiceBusSharedKeyCredential(shared_access_key_name, shared_access_key), **kwargs) + return cls(endpoint, credential, **kwargs) # type: ignore async def get_queue(self, queue_name: str, **kwargs) -> QueueProperties: """Get the properties of a queue. diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/management/_management_client.py b/sdk/servicebus/azure-servicebus/azure/servicebus/management/_management_client.py index fcb2de8e50f0..a2f680e1d2c1 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/management/_management_client.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/management/_management_client.py @@ -27,7 +27,7 @@ from .._common.constants import JWT_TOKEN_SCOPE from .._common.utils import parse_conn_str -from .._base_handler import ServiceBusSharedKeyCredential +from .._base_handler import ServiceBusSharedKeyCredential, ServiceBusSASTokenCredential from ._shared_key_policy import ServiceBusSharedKeyCredentialPolicy from ._generated._configuration import ServiceBusManagementClientConfiguration from ._generated._service_bus_management_client import ServiceBusManagementClient as ServiceBusManagementClientImpl @@ -46,11 +46,11 @@ class ServiceBusAdministrationClient: # pylint:disable=too-many-public-methods :param str fully_qualified_namespace: The fully qualified host name for the Service Bus namespace. :param credential: To authenticate to manage the entities of the ServiceBus namespace. - :type credential: Union[TokenCredential, azure.servicebus.ServiceBusSharedKeyCredential] + :type credential: TokenCredential """ def __init__(self, fully_qualified_namespace, credential, **kwargs): - # type: (str, Union[TokenCredential, ServiceBusSharedKeyCredential], Dict[str, Any]) -> None + # type: (str, TokenCredential, Dict[str, Any]) -> None self.fully_qualified_namespace = fully_qualified_namespace self._credential = credential self._endpoint = "https://" + fully_qualified_namespace @@ -130,10 +130,14 @@ def from_connection_string(cls, conn_str, **kwargs): :param str conn_str: The connection string of the Service Bus Namespace. :rtype: ~azure.servicebus.management.ServiceBusAdministrationClient """ - endpoint, shared_access_key_name, shared_access_key, _ = parse_conn_str(conn_str) + endpoint, shared_access_key_name, shared_access_key, _, token, token_expiry = parse_conn_str(conn_str) + if token and token_expiry: + credential = ServiceBusSASTokenCredential(token, token_expiry) + elif shared_access_key_name and shared_access_key: + credential = ServiceBusSharedKeyCredential(shared_access_key_name, shared_access_key) # type: ignore if "//" in endpoint: endpoint = endpoint[endpoint.index("//") + 2:] - return cls(endpoint, ServiceBusSharedKeyCredential(shared_access_key_name, shared_access_key), **kwargs) + return cls(endpoint, credential, **kwargs) def get_queue(self, queue_name, **kwargs): # type: (str, Any) -> QueueProperties diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_queues_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_queues_async.py index 36f6738adf18..c39ac80816e1 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_queues_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_queues_async.py @@ -10,7 +10,7 @@ from azure.core.exceptions import HttpResponseError, ResourceNotFoundError, ResourceExistsError from azure.servicebus.aio.management import ServiceBusAdministrationClient from azure.servicebus.management import QueueProperties -from azure.servicebus.aio import ServiceBusSharedKeyCredential +from azure.servicebus.aio._base_handler_async import ServiceBusSharedKeyCredential from azure.servicebus._common.utils import utc_now from devtools_testutils import AzureMgmtTestCase, CachedResourceGroupPreparer diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py index 351d1fe51a99..a51571caf082 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py @@ -9,6 +9,8 @@ import pytest from azure.servicebus.aio import ServiceBusClient +from azure.servicebus import Message +from azure.servicebus.aio._base_handler_async import ServiceBusSharedKeyCredential from devtools_testutils import AzureMgmtTestCase, CachedResourceGroupPreparer from servicebus_preparer import CachedServiceBusNamespacePreparer, CachedServiceBusQueuePreparer from utilities import get_logger @@ -58,3 +60,31 @@ async def test_async_sb_client_close_spawned_handlers(self, servicebus_namespace assert not sender._handler and not sender._running assert not receiver._handler and not receiver._running assert len(client._handlers) == 0 + + + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedResourceGroupPreparer() + @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') + @CachedServiceBusQueuePreparer(name_prefix='servicebustest') + async def test_client_sas_credential_async(self, + servicebus_queue, + servicebus_namespace, + servicebus_namespace_key_name, + servicebus_namespace_primary_key, + servicebus_namespace_connection_string, + **kwargs): + # This should "just work" to validate known-good. + credential = ServiceBusSharedKeyCredential(servicebus_namespace_key_name, servicebus_namespace_primary_key) + hostname = "{}.servicebus.windows.net".format(servicebus_namespace.name) + auth_uri = "sb://{}/{}".format(hostname, servicebus_queue.name) + token = (await credential.get_token(auth_uri)).token + + # Finally let's do it with SAS token + conn str + token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token.decode()) + + client = ServiceBusClient.from_connection_string(token_conn_str) + async with client: + assert len(client._handlers) == 0 + async with client.get_queue_sender(servicebus_queue.name) as sender: + await sender.send_messages(Message("foo")) diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/test_subscriptions_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/test_subscriptions_async.py index ce55e8190715..9a11c532a726 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/test_subscriptions_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/test_subscriptions_async.py @@ -12,7 +12,8 @@ from datetime import datetime, timedelta from azure.servicebus import Message, ReceiveMode -from azure.servicebus.aio import ServiceBusClient, ServiceBusSharedKeyCredential +from azure.servicebus.aio import ServiceBusClient +from azure.servicebus.aio._base_handler_async import ServiceBusSharedKeyCredential from azure.servicebus.exceptions import ServiceBusError from azure.servicebus._common.constants import SubQueue diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/test_topic_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/test_topic_async.py index efd4cf48de65..b1b1b2c1a5e9 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/test_topic_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/test_topic_async.py @@ -14,7 +14,8 @@ from devtools_testutils import AzureMgmtTestCase, RandomNameResourceGroupPreparer, CachedResourceGroupPreparer -from azure.servicebus.aio import ServiceBusClient, ServiceBusSharedKeyCredential +from azure.servicebus.aio import ServiceBusClient +from azure.servicebus.aio._base_handler_async import ServiceBusSharedKeyCredential from azure.servicebus._common.message import Message from servicebus_preparer import ( ServiceBusNamespacePreparer, diff --git a/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_queues.py b/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_queues.py index bdb9af0f7653..a584f343af49 100644 --- a/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_queues.py +++ b/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_queues.py @@ -15,7 +15,7 @@ from azure.servicebus._common.utils import utc_now from utilities import get_logger from azure.core.exceptions import HttpResponseError, ServiceRequestError, ResourceNotFoundError, ResourceExistsError -from azure.servicebus import ServiceBusSharedKeyCredential +from azure.servicebus._base_handler import ServiceBusSharedKeyCredential from devtools_testutils import AzureMgmtTestCase, CachedResourceGroupPreparer from servicebus_preparer import ( diff --git a/sdk/servicebus/azure-servicebus/tests/test_sb_client.py b/sdk/servicebus/azure-servicebus/tests/test_sb_client.py index eaaad26a9cce..539c5c45840c 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_sb_client.py +++ b/sdk/servicebus/azure-servicebus/tests/test_sb_client.py @@ -13,7 +13,8 @@ from azure.common import AzureHttpError, AzureConflictHttpError from azure.mgmt.servicebus.models import AccessRights -from azure.servicebus import ServiceBusClient, ServiceBusSharedKeyCredential, ServiceBusSender +from azure.servicebus import ServiceBusClient, ServiceBusSender +from azure.servicebus._base_handler import ServiceBusSharedKeyCredential from azure.servicebus._common.message import Message, PeekedMessage from azure.servicebus.exceptions import ( ServiceBusError, @@ -188,3 +189,40 @@ def test_sb_client_close_spawned_handlers(self, servicebus_namespace_connection_ assert not sender._handler and not sender._running assert not receiver._handler and not receiver._running assert len(client._handlers) == 0 + + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedResourceGroupPreparer() + @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') + @CachedServiceBusQueuePreparer(name_prefix='servicebustest') + def test_client_sas_credential(self, + servicebus_queue, + servicebus_namespace, + servicebus_namespace_key_name, + servicebus_namespace_primary_key, + servicebus_namespace_connection_string, + **kwargs): + # This should "just work" to validate known-good. + credential = ServiceBusSharedKeyCredential(servicebus_namespace_key_name, servicebus_namespace_primary_key) + hostname = "{}.servicebus.windows.net".format(servicebus_namespace.name) + auth_uri = "sb://{}/{}".format(hostname, servicebus_queue.name) + token = credential.get_token(auth_uri).token + + # Finally let's do it with SAS token + conn str + token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token.decode()) + + client = ServiceBusClient.from_connection_string(token_conn_str) + with client: + assert len(client._handlers) == 0 + with client.get_queue_sender(servicebus_queue.name) as sender: + sender.send_messages(Message("foo")) + + # This is disabled pending UAMQP fix https://github.com/Azure/azure-uamqp-python/issues/170 + # + #token_conn_str_without_se = token_conn_str.split('se=')[0] + token_conn_str.split('se=')[1].split('&')[1] + # + #client = ServiceBusClient.from_connection_string(token_conn_str_without_se) + #with client: + # assert len(client._handlers) == 0 + # with client.get_queue_sender(servicebus_queue.name) as sender: + # sender.send_messages(Message("foo")) \ No newline at end of file diff --git a/sdk/servicebus/azure-servicebus/tests/test_subscriptions.py b/sdk/servicebus/azure-servicebus/tests/test_subscriptions.py index 311897439941..70f3c91fce57 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_subscriptions.py +++ b/sdk/servicebus/azure-servicebus/tests/test_subscriptions.py @@ -11,7 +11,8 @@ import time from datetime import datetime, timedelta -from azure.servicebus import ServiceBusClient, Message, ReceiveMode, ServiceBusSharedKeyCredential +from azure.servicebus import ServiceBusClient, Message, ReceiveMode +from azure.servicebus._base_handler import ServiceBusSharedKeyCredential from azure.servicebus.exceptions import ServiceBusError from azure.servicebus._common.constants import SubQueue diff --git a/sdk/servicebus/azure-servicebus/tests/test_topic.py b/sdk/servicebus/azure-servicebus/tests/test_topic.py index ecbffa3cba1a..afa1d4b9fe96 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_topic.py +++ b/sdk/servicebus/azure-servicebus/tests/test_topic.py @@ -13,7 +13,8 @@ from devtools_testutils import AzureMgmtTestCase, RandomNameResourceGroupPreparer, CachedResourceGroupPreparer -from azure.servicebus import ServiceBusClient, ServiceBusSharedKeyCredential +from azure.servicebus import ServiceBusClient +from azure.servicebus._base_handler import ServiceBusSharedKeyCredential from azure.servicebus._common.message import Message from servicebus_preparer import ( ServiceBusNamespacePreparer,