Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ServiceBus] Support SAS token-via-connection-string auth, and remove ServiceBusSharedKeyCredential export #13627

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sdk/servicebus/azure-servicebus/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

**New Features**
* Messages can now be sent twice in succession.
* Connection strings now support
KieranBrantnerMagee marked this conversation as resolved.
Show resolved Hide resolved
* Connection strings used with `from_connection_string` methods now support using the `SharedAccessSignature` key in leiu of `sharedaccesskey` and `sharedaccesskeyname`, taking the string of the properly constructed token as value.
* Internal AMQP message properties (header, footer, annotations, properties, etc) are now exposed via `Message.amqp_message`

**Breaking Changes**
Expand All @@ -13,6 +15,7 @@
* Sending a message twice will no longer result in a MessageAlreadySettled exception.
* `ServiceBusClient.close()` now closes spawned senders and receivers.
* Attempting to initialize a sender or receiver with a different connection string entity and specified entity (e.g. `queue_name`) will result in an AuthenticationError
* No longer export `ServiceBusSharedKeyCredential`

## 7.0.0b5 (2020-08-10)

Expand Down
2 changes: 0 additions & 2 deletions sdk/servicebus/azure-servicebus/azure/servicebus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from ._servicebus_receiver import ServiceBusReceiver
from ._servicebus_session_receiver import ServiceBusSessionReceiver
from ._servicebus_session import ServiceBusSession
from ._base_handler import ServiceBusSharedKeyCredential
from ._common.message import Message, BatchMessage, PeekMessage, ReceivedMessage
from ._common.constants import ReceiveSettleMode, NEXT_AVAILABLE
from ._common.auto_lock_renewer import AutoLockRenew
Expand All @@ -32,7 +31,6 @@
'ServiceBusSessionReceiver',
'ServiceBusSession',
'ServiceBusSender',
'ServiceBusSharedKeyCredential',
'TransportType',
'AutoLockRenew'
]
99 changes: 75 additions & 24 deletions sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ASSOCIATEDLINKPROPERTYNAME
)

from azure.core.credentials import AccessToken
if TYPE_CHECKING:
from azure.core.credentials import TokenCredential

Expand All @@ -46,6 +47,8 @@ def _parse_conn_str(conn_str):
shared_access_key_name = None
shared_access_key = None
entity_path = None # type: Optional[str]
shared_access_signature = None # type: Optional[str]
shared_access_signature_expiry = None # type: Optional[int]
for element in conn_str.split(";"):
key, _, value = element.partition("=")
if key.lower() == "endpoint":
Expand All @@ -58,7 +61,16 @@ def _parse_conn_str(conn_str):
shared_access_key = value
elif key.lower() == "entitypath":
entity_path = value
if not all([endpoint, shared_access_key_name, shared_access_key]):
elif key.lower() == "sharedaccesssignature":
shared_access_signature = value
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
# type: ignore
shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0])
except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional.
# An arbitrary, absurdly large number, since you can't renew.
shared_access_signature_expiry = int(time.time() * 2)
Copy link
Contributor

@yunhaoling yunhaoling Sep 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not super confident on large number being implemented as time.time() * 2.

Have we tested this case -- no expiry provided in signature and we default to a very large number?

The concern I have here is that as you know our uamqp is built upon C, whether this would lead to integer overflow on certain platforms, making expiry being some negative value.

Copy link
Member Author

@KieranBrantnerMagee KieranBrantnerMagee Sep 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wisdom beyond your years. My inclination would be to use some vestige of INT_MAX. I'll poke around, but shout if you know off the cuff the incantation that I want to properly align with whatever is being done in the AMQP level.

(Worded more clearly; will I be shooting myself in the foot if I use sys.maxint as far as you know of uamqp?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per our OOB discussion: Writing this test ended up exposing the following uamqp bug

Added issue for future investigation, noting this in the test stub that we can reenable once it's fixed.

if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))):
raise ValueError(
"Invalid connection string. Should be in the format: "
"Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>"
Expand All @@ -69,7 +81,13 @@ def _parse_conn_str(conn_str):
host = cast(str, endpoint)[left_slash_pos + 2:]
else:
host = str(endpoint)
return host, str(shared_access_key_name), str(shared_access_key), entity

return (host,
str(shared_access_key_name) if shared_access_key_name else None,
str(shared_access_key) if shared_access_key else None,
entity,
str(shared_access_signature) if shared_access_signature else None,
shared_access_signature_expiry)


def _generate_sas_token(uri, policy, key, expiry=None):
Expand All @@ -90,29 +108,27 @@ def _generate_sas_token(uri, policy, key, expiry=None):
return _AccessToken(token=token, expires_on=abs_expiry)
KieranBrantnerMagee marked this conversation as resolved.
Show resolved Hide resolved


def _convert_connection_string_to_kwargs(conn_str, shared_key_credential_type, **kwargs):
# type: (str, Type, Any) -> Dict[str, Any]
host, policy, key, entity_in_conn_str = _parse_conn_str(conn_str)
queue_name = kwargs.get("queue_name")
topic_name = kwargs.get("topic_name")
if not (queue_name or topic_name or entity_in_conn_str):
raise ValueError("Entity name is missing. Please specify `queue_name` or `topic_name`"
" or use a connection string including the entity information.")

if queue_name and topic_name:
raise ValueError("`queue_name` and `topic_name` can not be specified simultaneously.")

entity_in_kwargs = queue_name or topic_name
if entity_in_conn_str and entity_in_kwargs and (entity_in_conn_str != entity_in_kwargs):
raise ServiceBusAuthenticationError(
"Entity names do not match, the entity name in connection string is {};"
" the entity name in parameter is {}.".format(entity_in_conn_str, entity_in_kwargs)
)
class ServiceBusSASTokenCredential(object):
"""The shared access token credential used for authentication.
:param str token: The shared access token string
:param int expiry: The epoch timestamp
"""
def __init__(self, token, expiry):
# type: (str, int) -> None
"""
:param str token: The shared access token string
:param float expiry: The epoch timestamp
"""
self.token = token
self.expiry = expiry
self.token_type = b"servicebus.windows.net:sastoken"

kwargs["fully_qualified_namespace"] = host
kwargs["entity_name"] = entity_in_conn_str or entity_in_kwargs
kwargs["credential"] = shared_key_credential_type(policy, key)
return kwargs
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (str, Any) -> AccessToken
"""
This method is automatically called when token is about to expire.
"""
return AccessToken(self.token, self.expiry)


class ServiceBusSharedKeyCredential(object):
Expand Down Expand Up @@ -158,6 +174,41 @@ def __init__(
self._auth_uri = None
self._properties = create_properties(self._config.user_agent)


def _convert_connection_string_to_kwargs(self, conn_str, **kwargs):
# type: (str, Type, Any) -> Dict[str, Any]
host, policy, key, entity_in_conn_str, token, token_expiry = _parse_conn_str(conn_str)
queue_name = kwargs.get("queue_name")
topic_name = kwargs.get("topic_name")
if not (queue_name or topic_name or entity_in_conn_str):
raise ValueError("Entity name is missing. Please specify `queue_name` or `topic_name`"
" or use a connection string including the entity information.")

if queue_name and topic_name:
raise ValueError("`queue_name` and `topic_name` can not be specified simultaneously.")

entity_in_kwargs = queue_name or topic_name
if entity_in_conn_str and entity_in_kwargs and (entity_in_conn_str != entity_in_kwargs):
raise ServiceBusAuthenticationError(
"Entity names do not match, the entity name in connection string is {};"
" the entity name in parameter is {}.".format(entity_in_conn_str, entity_in_kwargs)
)

kwargs["fully_qualified_namespace"] = host
kwargs["entity_name"] = entity_in_conn_str or entity_in_kwargs
# This has to be defined seperately to support sync vs async credentials.
kwargs["credential"] = self._create_credential_from_connection_string_parameters(token,
token_expiry,
policy,
key)
return kwargs

def _create_credential_from_connection_string_parameters(self, token, token_expiry, policy, key):
if token and token_expiry:
return ServiceBusSASTokenCredential(token, token_expiry)
elif policy and key:
return ServiceBusSharedKeyCredential(policy, key)
KieranBrantnerMagee marked this conversation as resolved.
Show resolved Hide resolved

def __enter__(self):
self._open_with_retry()
return self
Expand Down
22 changes: 20 additions & 2 deletions sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,15 @@ def utc_now():
return datetime.datetime.now(tz=TZ_UTC)


# This parse_conn_str is used for mgmt, the other in base_handler for handlers. Should be unified.
def parse_conn_str(conn_str):
# type: (str) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
endpoint = None
shared_access_key_name = None
shared_access_key = None
entity_path = None
shared_access_signature = None # type: Optional[str]
shared_access_signature_expiry = None # type: Optional[int]
for element in conn_str.split(';'):
key, _, value = element.partition('=')
if key.lower() == 'endpoint':
Expand All @@ -78,9 +82,23 @@ def parse_conn_str(conn_str):
shared_access_key = value
elif key.lower() == 'entitypath':
entity_path = value
if not all([endpoint, shared_access_key_name, shared_access_key]):
elif key.lower() == "sharedaccesssignature":
shared_access_signature = value
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
# type: ignore
shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0])
except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional.
# An arbitrary, absurdly large number, since you can't renew.
shared_access_signature_expiry = int(time.time() * 2)
if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))):
raise ValueError("Invalid connection string")
return endpoint, shared_access_key_name, shared_access_key, entity_path
return (endpoint,
str(shared_access_key_name) if shared_access_key_name else None,
str(shared_access_key) if shared_access_key else None,
entity_path,
str(shared_access_signature) if shared_access_signature else None,
shared_access_signature_expiry)


def build_uri(address, entity):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

import uamqp

from ._base_handler import _parse_conn_str, ServiceBusSharedKeyCredential, BaseHandler
from ._base_handler import (
_parse_conn_str,
ServiceBusSharedKeyCredential,
ServiceBusSASTokenCredential,
BaseHandler)
from ._servicebus_sender import ServiceBusSender
from ._servicebus_receiver import ServiceBusReceiver
from ._servicebus_session_receiver import ServiceBusSessionReceiver
Expand All @@ -32,8 +36,8 @@ class ServiceBusClient(object):
The namespace format is: `<yournamespace>.servicebus.windows.net`.
:param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which
implements a particular interface for getting tokens. It accepts
:class:`ServiceBusSharedKeyCredential<azure.servicebus.ServiceBusSharedKeyCredential>`, or credential objects
generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method.
:class: credential objects generated by the azure-identity library and objects that implement the
`get_token(self, *scopes)` method.
:keyword str entity_name: Optional entity name, this can be the name of Queue or Topic.
It must be specified if the credential is for specific Queue or Topic.
:keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`.
Expand Down Expand Up @@ -145,11 +149,15 @@ def from_connection_string(
:caption: Create a new instance of the ServiceBusClient from connection string.

"""
host, policy, key, entity_in_conn_str = _parse_conn_str(conn_str)
host, policy, key, entity_in_conn_str, token, token_expiry = _parse_conn_str(conn_str)
if token and token_expiry:
credential = ServiceBusSASTokenCredential(token, token_expiry)
elif policy and key:
credential = ServiceBusSharedKeyCredential(policy, key)
KieranBrantnerMagee marked this conversation as resolved.
Show resolved Hide resolved
return cls(
fully_qualified_namespace=host,
entity_name=entity_in_conn_str or kwargs.pop("entity_name", None),
credential=ServiceBusSharedKeyCredential(policy, key), # type: ignore
credential=credential, # type: ignore
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from uamqp.constants import SenderSettleMode
from uamqp.authentication.common import AMQPAuth

from ._base_handler import BaseHandler, ServiceBusSharedKeyCredential, _convert_connection_string_to_kwargs
from ._base_handler import BaseHandler, ServiceBusSharedKeyCredential
from ._common.utils import create_authentication
from ._common.message import PeekMessage, ReceivedMessage
from ._common.constants import (
Expand Down Expand Up @@ -54,8 +54,8 @@ class ServiceBusReceiver(BaseHandler, ReceiverMixin): # pylint: disable=too-man
The namespace format is: `<yournamespace>.servicebus.windows.net`.
:param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which
implements a particular interface for getting tokens. It accepts
:class:`ServiceBusSharedKeyCredential<azure.servicebus.ServiceBusSharedKeyCredential>`, or credential objects
generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method.
:class: credential objects generated by the azure-identity library and objects that implement the
`get_token(self, *scopes)` method.
:keyword str queue_name: The path of specific Service Bus Queue the client connects to.
:keyword str topic_name: The path of specific Service Bus Topic which contains the Subscription
the client connects to.
Expand Down Expand Up @@ -365,9 +365,8 @@ def from_connection_string(
:caption: Create a new instance of the ServiceBusReceiver from connection string.

"""
constructor_args = _convert_connection_string_to_kwargs(
constructor_args = self._convert_connection_string_to_kwargs(
conn_str,
ServiceBusSharedKeyCredential,
**kwargs
)
if kwargs.get("queue_name") and kwargs.get("subscription_name"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from uamqp import SendClient, types
from uamqp.authentication.common import AMQPAuth

from ._base_handler import BaseHandler, ServiceBusSharedKeyCredential, _convert_connection_string_to_kwargs
from ._base_handler import BaseHandler, ServiceBusSharedKeyCredential
from ._common import mgmt_handlers
from ._common.message import Message, BatchMessage
from .exceptions import (
Expand Down Expand Up @@ -97,8 +97,8 @@ class ServiceBusSender(BaseHandler, SenderMixin):
The namespace format is: `<yournamespace>.servicebus.windows.net`.
:param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which
implements a particular interface for getting tokens. It accepts
:class:`ServiceBusSharedKeyCredential<azure.servicebus.ServiceBusSharedKeyCredential>`, or credential objects
generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method.
:class: credential objects generated by the azure-identity library and objects that implement the
`get_token(self, *scopes)` method.
:keyword str queue_name: The path of specific Service Bus Queue the client connects to.
:keyword str topic_name: The path of specific Service Bus Topic the client connects to.
:keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`.
Expand Down Expand Up @@ -297,9 +297,8 @@ def from_connection_string(
:caption: Create a new instance of the ServiceBusSender from connection string.

"""
constructor_args = _convert_connection_string_to_kwargs(
constructor_args = self._convert_connection_string_to_kwargs(
conn_str,
ServiceBusSharedKeyCredential,
**kwargs
)
return cls(**constructor_args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class ServiceBusSessionReceiver(ServiceBusReceiver, SessionReceiverMixin):
The namespace format is: `<yournamespace>.servicebus.windows.net`.
:param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which
implements a particular interface for getting tokens. It accepts
:class:`ServiceBusSharedKeyCredential<azure.servicebus.ServiceBusSharedKeyCredential>`, or credential objects
generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method.
:class: credential objects generated by the azure-identity library and objects that implement the
`get_token(self, *scopes)` method.
:keyword str queue_name: The path of specific Service Bus Queue the client connects to.
:keyword str topic_name: The path of specific Service Bus Topic which contains the Subscription
the client connects to.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# license information.
# -------------------------------------------------------------------------
from ._async_message import ReceivedMessage
from ._base_handler_async import ServiceBusSharedKeyCredential
from ._servicebus_sender_async import ServiceBusSender
from ._servicebus_receiver_async import ServiceBusReceiver
from ._servicebus_session_receiver_async import ServiceBusSessionReceiver
Expand All @@ -19,6 +18,5 @@
'ServiceBusReceiver',
'ServiceBusSessionReceiver',
'ServiceBusSession',
'ServiceBusSharedKeyCredential',
'AutoLockRenew'
]
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,34 @@
_create_servicebus_exception
)

from azure.core.credentials import AccessToken
if TYPE_CHECKING:
from azure.core.credentials import TokenCredential, AccessToken
from azure.core.credentials import TokenCredential

_LOGGER = logging.getLogger(__name__)


class ServiceBusSASTokenCredential(object):
"""The shared access token credential used for authentication.
:param str token: The shared access token string
:param int expiry: The epoch timestamp
"""
def __init__(self, token: str, expiry: int) -> None:
"""
:param str token: The shared access token string
:param int expiry: The epoch timestamp
"""
self.token = token
self.expiry = expiry
self.token_type = b"servicebus.windows.net:sastoken"

async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument
"""
This method is automatically called when token is about to expire.
"""
return AccessToken(self.token, self.expiry)


class ServiceBusSharedKeyCredential(object):
"""The shared access key credential used for authentication.

Expand Down Expand Up @@ -69,6 +91,15 @@ def __init__(
self._auth_uri = None
self._properties = create_properties(self._config.user_agent)

def _convert_connection_string_to_kwargs(self, conn_str, **kwargs):
return BaseHandlerSync._convert_connection_string_to_kwargs(self, conn_str, kwargs)

def _create_credential_from_connection_string_parameters(self, token, token_expiry, policy, key):
if token and token_expiry:
return ServiceBusSASTokenCredential(token, token_expiry)
elif policy and key:
return ServiceBusSharedKeyCredential(policy, key)

async def __aenter__(self):
await self._open_with_retry()
return self
Expand Down
Loading