Skip to content

Commit

Permalink
[ServiceBus] Support SAS token-via-connection-string auth, and remove…
Browse files Browse the repository at this point in the history
… ServiceBusSharedKeyCredential export (#13627)

- Remove public documentation and exports of ServiceBusSharedKeyCredential until we chose to release it across all languages.
- Support for Sas Token connection strings (tests, etc)
- Add safety net for if signature and key are both provided in connstr (inspired by .nets approach)

Co-authored-by: Rakshith Bhyravabhotla <[email protected]>
  • Loading branch information
KieranBrantnerMagee and rakshith91 authored Sep 10, 2020
1 parent d91e0f5 commit bbcc72f
Show file tree
Hide file tree
Showing 24 changed files with 290 additions and 93 deletions.
2 changes: 2 additions & 0 deletions sdk/servicebus/azure-servicebus/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* `renew_lock()` now returns the UTC datetime that the lock is set to expire at.
* `receive_deferred_messages()` can now take a single sequence number as well as a list of sequence numbers.
* Messages can now be sent twice in succession.
* Connection strings used with `from_connection_string` methods now support using the `SharedAccessSignature` key in leiu of `sharedaccesskey` and `sharedaccesskeyname`, taking the string of the properly constructed token as value.
* Internal AMQP message properties (header, footer, annotations, properties, etc) are now exposed via `Message.amqp_message`

**Breaking Changes**
Expand All @@ -31,6 +32,7 @@
* Remove `support_ordering` from `create_queue` and `QueueProperties`
* Remove `enable_subscription_partitioning` from `create_topic` and `TopicProperties`
* `get_dead_letter_[queue,subscription]_receiver()` has been removed. To connect to a dead letter queue, utilize the `sub_queue` parameter of `get_[queue,subscription]_receiver()` provided with a value from the `SubQueue` enum
* No longer export `ServiceBusSharedKeyCredential`
* Rename `entity_availability_status` to `availability_status`

## 7.0.0b5 (2020-08-10)
Expand Down
2 changes: 0 additions & 2 deletions sdk/servicebus/azure-servicebus/azure/servicebus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from ._servicebus_receiver import ServiceBusReceiver
from ._servicebus_session_receiver import ServiceBusSessionReceiver
from ._servicebus_session import ServiceBusSession
from ._base_handler import ServiceBusSharedKeyCredential
from ._common.message import Message, BatchMessage, PeekedMessage, ReceivedMessage
from ._common.constants import ReceiveMode, SubQueue
from ._common.auto_lock_renewer import AutoLockRenew
Expand All @@ -32,7 +31,6 @@
'ServiceBusSessionReceiver',
'ServiceBusSession',
'ServiceBusSender',
'ServiceBusSharedKeyCredential',
'TransportType',
'AutoLockRenew'
]
106 changes: 80 additions & 26 deletions sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import uuid
import time
from datetime import timedelta
from typing import cast, Optional, Tuple, TYPE_CHECKING, Dict, Any, Callable, Type
from typing import cast, Optional, Tuple, TYPE_CHECKING, Dict, Any, Callable

try:
from urllib import quote_plus # type: ignore
Expand All @@ -18,6 +18,8 @@
from uamqp import utils
from uamqp.message import MessageProperties

from azure.core.credentials import AccessToken

from ._common._configuration import Configuration
from .exceptions import (
ServiceBusError,
Expand All @@ -41,11 +43,13 @@


def _parse_conn_str(conn_str):
# type: (str) -> Tuple[str, str, str, str]
# type: (str) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
endpoint = None
shared_access_key_name = None
shared_access_key = None
entity_path = None # type: Optional[str]
shared_access_signature = None # type: Optional[str]
shared_access_signature_expiry = None # type: Optional[int]
for element in conn_str.split(";"):
key, _, value = element.partition("=")
if key.lower() == "endpoint":
Expand All @@ -58,18 +62,35 @@ def _parse_conn_str(conn_str):
shared_access_key = value
elif key.lower() == "entitypath":
entity_path = value
if not all([endpoint, shared_access_key_name, shared_access_key]):
elif key.lower() == "sharedaccesssignature":
shared_access_signature = value
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
# type: ignore
shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0])
except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional.
# An arbitrary, absurdly large number, since you can't renew.
shared_access_signature_expiry = int(time.time() * 2)
if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))) \
or all((shared_access_key_name, shared_access_signature)): # this latter clause since we don't accept both
raise ValueError(
"Invalid connection string. Should be in the format: "
"Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>"
"\nWith alternate option of providing SharedAccessSignature instead of SharedAccessKeyName and Key"
)
entity = cast(str, entity_path)
left_slash_pos = cast(str, endpoint).find("//")
if left_slash_pos != -1:
host = cast(str, endpoint)[left_slash_pos + 2:]
else:
host = str(endpoint)
return host, str(shared_access_key_name), str(shared_access_key), entity

return (host,
str(shared_access_key_name) if shared_access_key_name else None,
str(shared_access_key) if shared_access_key else None,
entity,
str(shared_access_signature) if shared_access_signature else None,
shared_access_signature_expiry)


def _generate_sas_token(uri, policy, key, expiry=None):
Expand All @@ -90,29 +111,27 @@ def _generate_sas_token(uri, policy, key, expiry=None):
return _AccessToken(token=token, expires_on=abs_expiry)


def _convert_connection_string_to_kwargs(conn_str, shared_key_credential_type, **kwargs):
# type: (str, Type, Any) -> Dict[str, Any]
host, policy, key, entity_in_conn_str = _parse_conn_str(conn_str)
queue_name = kwargs.get("queue_name")
topic_name = kwargs.get("topic_name")
if not (queue_name or topic_name or entity_in_conn_str):
raise ValueError("Entity name is missing. Please specify `queue_name` or `topic_name`"
" or use a connection string including the entity information.")

if queue_name and topic_name:
raise ValueError("`queue_name` and `topic_name` can not be specified simultaneously.")

entity_in_kwargs = queue_name or topic_name
if entity_in_conn_str and entity_in_kwargs and (entity_in_conn_str != entity_in_kwargs):
raise ServiceBusAuthenticationError(
"Entity names do not match, the entity name in connection string is {};"
" the entity name in parameter is {}.".format(entity_in_conn_str, entity_in_kwargs)
)
class ServiceBusSASTokenCredential(object):
"""The shared access token credential used for authentication.
:param str token: The shared access token string
:param int expiry: The epoch timestamp
"""
def __init__(self, token, expiry):
# type: (str, int) -> None
"""
:param str token: The shared access token string
:param float expiry: The epoch timestamp
"""
self.token = token
self.expiry = expiry
self.token_type = b"servicebus.windows.net:sastoken"

kwargs["fully_qualified_namespace"] = host
kwargs["entity_name"] = entity_in_conn_str or entity_in_kwargs
kwargs["credential"] = shared_key_credential_type(policy, key)
return kwargs
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (str, Any) -> AccessToken
"""
This method is automatically called when token is about to expire.
"""
return AccessToken(self.token, self.expiry)


class ServiceBusSharedKeyCredential(object):
Expand Down Expand Up @@ -158,6 +177,41 @@ def __init__(
self._auth_uri = None
self._properties = create_properties(self._config.user_agent)

@classmethod
def _convert_connection_string_to_kwargs(cls, conn_str, **kwargs):
# type: (str, Any) -> Dict[str, Any]
host, policy, key, entity_in_conn_str, token, token_expiry = _parse_conn_str(conn_str)
queue_name = kwargs.get("queue_name")
topic_name = kwargs.get("topic_name")
if not (queue_name or topic_name or entity_in_conn_str):
raise ValueError("Entity name is missing. Please specify `queue_name` or `topic_name`"
" or use a connection string including the entity information.")

if queue_name and topic_name:
raise ValueError("`queue_name` and `topic_name` can not be specified simultaneously.")

entity_in_kwargs = queue_name or topic_name
if entity_in_conn_str and entity_in_kwargs and (entity_in_conn_str != entity_in_kwargs):
raise ServiceBusAuthenticationError(
"Entity names do not match, the entity name in connection string is {};"
" the entity name in parameter is {}.".format(entity_in_conn_str, entity_in_kwargs)
)

kwargs["fully_qualified_namespace"] = host
kwargs["entity_name"] = entity_in_conn_str or entity_in_kwargs
# This has to be defined seperately to support sync vs async credentials.
kwargs["credential"] = cls._create_credential_from_connection_string_parameters(token,
token_expiry,
policy,
key)
return kwargs

@classmethod
def _create_credential_from_connection_string_parameters(cls, token, token_expiry, policy, key):
if token and token_expiry:
return ServiceBusSASTokenCredential(token, token_expiry)
return ServiceBusSharedKeyCredential(policy, key)

def __enter__(self):
self._open_with_retry()
return self
Expand Down
40 changes: 32 additions & 8 deletions sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import logging
import functools
import platform
from typing import Optional, Dict
import time
from typing import Optional, Dict, Tuple
try:
from urlparse import urlparse
except ImportError:
Expand Down Expand Up @@ -63,11 +64,15 @@ def utc_now():
return datetime.datetime.now(tz=TZ_UTC)


# This parse_conn_str is used for mgmt, the other in base_handler for handlers. Should be unified.
def parse_conn_str(conn_str):
endpoint = None
shared_access_key_name = None
shared_access_key = None
entity_path = None
# type: (str) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
endpoint = ''
shared_access_key_name = None # type: Optional[str]
shared_access_key = None # type: Optional[str]
entity_path = ''
shared_access_signature = None # type: Optional[str]
shared_access_signature_expiry = None # type: Optional[int]
for element in conn_str.split(';'):
key, _, value = element.partition('=')
if key.lower() == 'endpoint':
Expand All @@ -78,9 +83,28 @@ def parse_conn_str(conn_str):
shared_access_key = value
elif key.lower() == 'entitypath':
entity_path = value
if not all([endpoint, shared_access_key_name, shared_access_key]):
raise ValueError("Invalid connection string")
return endpoint, shared_access_key_name, shared_access_key, entity_path
elif key.lower() == "sharedaccesssignature":
shared_access_signature = value
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
# type: ignore
shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0])
except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional.
# An arbitrary, absurdly large number, since you can't renew.
shared_access_signature_expiry = int(time.time() * 2)
if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))) \
or all((shared_access_key_name, shared_access_signature)): # this latter clause since we don't accept both
raise ValueError(
"Invalid connection string. Should be in the format: "
"Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>"
"\nWith alternate option of providing SharedAccessSignature instead of SharedAccessKeyName and Key"
)
return (endpoint,
str(shared_access_key_name) if shared_access_key_name else None,
str(shared_access_key) if shared_access_key else None,
entity_path,
str(shared_access_signature) if shared_access_signature else None,
shared_access_signature_expiry)


def build_uri(address, entity):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

import uamqp

from ._base_handler import _parse_conn_str, ServiceBusSharedKeyCredential, BaseHandler
from ._base_handler import (
_parse_conn_str,
ServiceBusSharedKeyCredential,
ServiceBusSASTokenCredential,
BaseHandler)
from ._servicebus_sender import ServiceBusSender
from ._servicebus_receiver import ServiceBusReceiver
from ._servicebus_session_receiver import ServiceBusSessionReceiver
Expand All @@ -33,8 +37,8 @@ class ServiceBusClient(object):
The namespace format is: `<yournamespace>.servicebus.windows.net`.
:param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which
implements a particular interface for getting tokens. It accepts
:class:`ServiceBusSharedKeyCredential<azure.servicebus.ServiceBusSharedKeyCredential>`, or credential objects
generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method.
credential objects generated by the azure-identity library and objects that implement the
`get_token(self, *scopes)` method.
:keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`.
:keyword transport_type: The type of transport protocol that will be used for communicating with
the Service Bus service. Default is `TransportType.Amqp`.
Expand Down Expand Up @@ -153,11 +157,15 @@ def from_connection_string(
:caption: Create a new instance of the ServiceBusClient from connection string.
"""
host, policy, key, entity_in_conn_str = _parse_conn_str(conn_str)
host, policy, key, entity_in_conn_str, token, token_expiry = _parse_conn_str(conn_str)
if token and token_expiry:
credential = ServiceBusSASTokenCredential(token, token_expiry)
elif policy and key:
credential = ServiceBusSharedKeyCredential(policy, key) # type: ignore
return cls(
fully_qualified_namespace=host,
entity_name=entity_in_conn_str or kwargs.pop("entity_name", None),
credential=ServiceBusSharedKeyCredential(policy, key), # type: ignore
credential=credential, # type: ignore
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from uamqp.constants import SenderSettleMode
from uamqp.authentication.common import AMQPAuth

from ._base_handler import BaseHandler, ServiceBusSharedKeyCredential, _convert_connection_string_to_kwargs
from ._base_handler import BaseHandler
from ._common.utils import create_authentication
from ._common.message import PeekedMessage, ReceivedMessage
from ._common.constants import (
Expand Down Expand Up @@ -56,8 +56,8 @@ class ServiceBusReceiver(BaseHandler, ReceiverMixin): # pylint: disable=too-man
The namespace format is: `<yournamespace>.servicebus.windows.net`.
:param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which
implements a particular interface for getting tokens. It accepts
:class:`ServiceBusSharedKeyCredential<azure.servicebus.ServiceBusSharedKeyCredential>`, or credential objects
generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method.
:class: credential objects generated by the azure-identity library and objects that implement the
`get_token(self, *scopes)` method.
:keyword str queue_name: The path of specific Service Bus Queue the client connects to.
:keyword str topic_name: The path of specific Service Bus Topic which contains the Subscription
the client connects to.
Expand Down Expand Up @@ -363,9 +363,8 @@ def from_connection_string(
:caption: Create a new instance of the ServiceBusReceiver from connection string.
"""
constructor_args = _convert_connection_string_to_kwargs(
constructor_args = cls._convert_connection_string_to_kwargs(
conn_str,
ServiceBusSharedKeyCredential,
**kwargs
)
if kwargs.get("queue_name") and kwargs.get("subscription_name"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from uamqp import SendClient, types
from uamqp.authentication.common import AMQPAuth

from ._base_handler import BaseHandler, ServiceBusSharedKeyCredential, _convert_connection_string_to_kwargs
from ._base_handler import BaseHandler
from ._common import mgmt_handlers
from ._common.message import Message, BatchMessage
from .exceptions import (
Expand Down Expand Up @@ -97,8 +97,8 @@ class ServiceBusSender(BaseHandler, SenderMixin):
The namespace format is: `<yournamespace>.servicebus.windows.net`.
:param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which
implements a particular interface for getting tokens. It accepts
:class:`ServiceBusSharedKeyCredential<azure.servicebus.ServiceBusSharedKeyCredential>`, or credential objects
generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method.
:class: credential objects generated by the azure-identity library and objects that implement the
`get_token(self, *scopes)` method.
:keyword str queue_name: The path of specific Service Bus Queue the client connects to.
:keyword str topic_name: The path of specific Service Bus Topic the client connects to.
:keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`.
Expand Down Expand Up @@ -293,9 +293,8 @@ def from_connection_string(
:caption: Create a new instance of the ServiceBusSender from connection string.
"""
constructor_args = _convert_connection_string_to_kwargs(
constructor_args = cls._convert_connection_string_to_kwargs(
conn_str,
ServiceBusSharedKeyCredential,
**kwargs
)
return cls(**constructor_args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class ServiceBusSessionReceiver(ServiceBusReceiver, SessionReceiverMixin):
The namespace format is: `<yournamespace>.servicebus.windows.net`.
:param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which
implements a particular interface for getting tokens. It accepts
:class:`ServiceBusSharedKeyCredential<azure.servicebus.ServiceBusSharedKeyCredential>`, or credential objects
generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method.
:class: credential objects generated by the azure-identity library and objects that implement the
`get_token(self, *scopes)` method.
:keyword str queue_name: The path of specific Service Bus Queue the client connects to.
:keyword str topic_name: The path of specific Service Bus Topic which contains the Subscription
the client connects to.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# license information.
# -------------------------------------------------------------------------
from ._async_message import ReceivedMessage
from ._base_handler_async import ServiceBusSharedKeyCredential
from ._servicebus_sender_async import ServiceBusSender
from ._servicebus_receiver_async import ServiceBusReceiver
from ._servicebus_session_receiver_async import ServiceBusSessionReceiver
Expand All @@ -19,6 +18,5 @@
'ServiceBusReceiver',
'ServiceBusSessionReceiver',
'ServiceBusSession',
'ServiceBusSharedKeyCredential',
'AutoLockRenew'
]
Loading

0 comments on commit bbcc72f

Please sign in to comment.