Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ServiceBus] Support SAS token-via-connection-string auth, and remove ServiceBusSharedKeyCredential export #13627

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/servicebus/azure-servicebus/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* `renew_lock()` now returns the UTC datetime that the lock is set to expire at.
* `receive_deferred_messages()` can now take a single sequence number as well as a list of sequence numbers.
* Messages can now be sent twice in succession.
* Connection strings used with `from_connection_string` methods now support using the `SharedAccessSignature` key in leiu of `sharedaccesskey` and `sharedaccesskeyname`, taking the string of the properly constructed token as value.
* Internal AMQP message properties (header, footer, annotations, properties, etc) are now exposed via `Message.amqp_message`

**Breaking Changes**
Expand All @@ -31,6 +32,7 @@
* Remove `support_ordering` from `create_queue` and `QueueProperties`
* Remove `enable_subscription_partitioning` from `create_topic` and `TopicProperties`
* `get_dead_letter_[queue,subscription]_receiver()` has been removed. To connect to a dead letter queue, utilize the `sub_queue` parameter of `get_[queue,subscription]_receiver()` provided with a value from the `SubQueue` enum
* No longer export `ServiceBusSharedKeyCredential`
* Rename `entity_availability_status` to `availability_status`

## 7.0.0b5 (2020-08-10)
Expand Down
2 changes: 0 additions & 2 deletions sdk/servicebus/azure-servicebus/azure/servicebus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from ._servicebus_receiver import ServiceBusReceiver
from ._servicebus_session_receiver import ServiceBusSessionReceiver
from ._servicebus_session import ServiceBusSession
from ._base_handler import ServiceBusSharedKeyCredential
from ._common.message import Message, BatchMessage, PeekedMessage, ReceivedMessage
from ._common.constants import ReceiveMode, SubQueue
from ._common.auto_lock_renewer import AutoLockRenew
Expand All @@ -32,7 +31,6 @@
'ServiceBusSessionReceiver',
'ServiceBusSession',
'ServiceBusSender',
'ServiceBusSharedKeyCredential',
'TransportType',
'AutoLockRenew'
]
106 changes: 80 additions & 26 deletions sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import uuid
import time
from datetime import timedelta
from typing import cast, Optional, Tuple, TYPE_CHECKING, Dict, Any, Callable, Type
from typing import cast, Optional, Tuple, TYPE_CHECKING, Dict, Any, Callable

try:
from urllib import quote_plus # type: ignore
Expand All @@ -18,6 +18,8 @@
from uamqp import utils
from uamqp.message import MessageProperties

from azure.core.credentials import AccessToken

from ._common._configuration import Configuration
from .exceptions import (
ServiceBusError,
Expand All @@ -41,11 +43,13 @@


def _parse_conn_str(conn_str):
# type: (str) -> Tuple[str, str, str, str]
# type: (str) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
endpoint = None
shared_access_key_name = None
shared_access_key = None
entity_path = None # type: Optional[str]
shared_access_signature = None # type: Optional[str]
shared_access_signature_expiry = None # type: Optional[int]
for element in conn_str.split(";"):
key, _, value = element.partition("=")
if key.lower() == "endpoint":
Expand All @@ -58,18 +62,35 @@ def _parse_conn_str(conn_str):
shared_access_key = value
elif key.lower() == "entitypath":
entity_path = value
if not all([endpoint, shared_access_key_name, shared_access_key]):
elif key.lower() == "sharedaccesssignature":
shared_access_signature = value
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
# type: ignore
shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0])
except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional.
# An arbitrary, absurdly large number, since you can't renew.
shared_access_signature_expiry = int(time.time() * 2)
Copy link
Contributor

@yunhaoling yunhaoling Sep 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not super confident on large number being implemented as time.time() * 2.

Have we tested this case -- no expiry provided in signature and we default to a very large number?

The concern I have here is that as you know our uamqp is built upon C, whether this would lead to integer overflow on certain platforms, making expiry being some negative value.

Copy link
Member Author

@KieranBrantnerMagee KieranBrantnerMagee Sep 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wisdom beyond your years. My inclination would be to use some vestige of INT_MAX. I'll poke around, but shout if you know off the cuff the incantation that I want to properly align with whatever is being done in the AMQP level.

(Worded more clearly; will I be shooting myself in the foot if I use sys.maxint as far as you know of uamqp?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per our OOB discussion: Writing this test ended up exposing the following uamqp bug

Added issue for future investigation, noting this in the test stub that we can reenable once it's fixed.

if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))) \
or all((shared_access_key_name, shared_access_signature)): # this latter clause since we don't accept both
raise ValueError(
"Invalid connection string. Should be in the format: "
"Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>"
"\nWith alternate option of providing SharedAccessSignature instead of SharedAccessKeyName and Key"
)
entity = cast(str, entity_path)
left_slash_pos = cast(str, endpoint).find("//")
if left_slash_pos != -1:
host = cast(str, endpoint)[left_slash_pos + 2:]
else:
host = str(endpoint)
return host, str(shared_access_key_name), str(shared_access_key), entity

return (host,
str(shared_access_key_name) if shared_access_key_name else None,
str(shared_access_key) if shared_access_key else None,
entity,
str(shared_access_signature) if shared_access_signature else None,
shared_access_signature_expiry)


def _generate_sas_token(uri, policy, key, expiry=None):
Expand All @@ -90,29 +111,27 @@ def _generate_sas_token(uri, policy, key, expiry=None):
return _AccessToken(token=token, expires_on=abs_expiry)
KieranBrantnerMagee marked this conversation as resolved.
Show resolved Hide resolved


def _convert_connection_string_to_kwargs(conn_str, shared_key_credential_type, **kwargs):
# type: (str, Type, Any) -> Dict[str, Any]
host, policy, key, entity_in_conn_str = _parse_conn_str(conn_str)
queue_name = kwargs.get("queue_name")
topic_name = kwargs.get("topic_name")
if not (queue_name or topic_name or entity_in_conn_str):
raise ValueError("Entity name is missing. Please specify `queue_name` or `topic_name`"
" or use a connection string including the entity information.")

if queue_name and topic_name:
raise ValueError("`queue_name` and `topic_name` can not be specified simultaneously.")

entity_in_kwargs = queue_name or topic_name
if entity_in_conn_str and entity_in_kwargs and (entity_in_conn_str != entity_in_kwargs):
raise ServiceBusAuthenticationError(
"Entity names do not match, the entity name in connection string is {};"
" the entity name in parameter is {}.".format(entity_in_conn_str, entity_in_kwargs)
)
class ServiceBusSASTokenCredential(object):
"""The shared access token credential used for authentication.
:param str token: The shared access token string
:param int expiry: The epoch timestamp
"""
def __init__(self, token, expiry):
# type: (str, int) -> None
"""
:param str token: The shared access token string
:param float expiry: The epoch timestamp
"""
self.token = token
self.expiry = expiry
self.token_type = b"servicebus.windows.net:sastoken"

kwargs["fully_qualified_namespace"] = host
kwargs["entity_name"] = entity_in_conn_str or entity_in_kwargs
kwargs["credential"] = shared_key_credential_type(policy, key)
return kwargs
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (str, Any) -> AccessToken
"""
This method is automatically called when token is about to expire.
"""
return AccessToken(self.token, self.expiry)


class ServiceBusSharedKeyCredential(object):
Expand Down Expand Up @@ -158,6 +177,41 @@ def __init__(
self._auth_uri = None
self._properties = create_properties(self._config.user_agent)

@classmethod
def _convert_connection_string_to_kwargs(cls, conn_str, **kwargs):
# type: (str, Any) -> Dict[str, Any]
host, policy, key, entity_in_conn_str, token, token_expiry = _parse_conn_str(conn_str)
queue_name = kwargs.get("queue_name")
topic_name = kwargs.get("topic_name")
if not (queue_name or topic_name or entity_in_conn_str):
raise ValueError("Entity name is missing. Please specify `queue_name` or `topic_name`"
" or use a connection string including the entity information.")

if queue_name and topic_name:
raise ValueError("`queue_name` and `topic_name` can not be specified simultaneously.")

entity_in_kwargs = queue_name or topic_name
if entity_in_conn_str and entity_in_kwargs and (entity_in_conn_str != entity_in_kwargs):
raise ServiceBusAuthenticationError(
"Entity names do not match, the entity name in connection string is {};"
" the entity name in parameter is {}.".format(entity_in_conn_str, entity_in_kwargs)
)

kwargs["fully_qualified_namespace"] = host
kwargs["entity_name"] = entity_in_conn_str or entity_in_kwargs
# This has to be defined seperately to support sync vs async credentials.
kwargs["credential"] = cls._create_credential_from_connection_string_parameters(token,
token_expiry,
policy,
key)
return kwargs

@classmethod
def _create_credential_from_connection_string_parameters(cls, token, token_expiry, policy, key):
if token and token_expiry:
return ServiceBusSASTokenCredential(token, token_expiry)
return ServiceBusSharedKeyCredential(policy, key)
KieranBrantnerMagee marked this conversation as resolved.
Show resolved Hide resolved

def __enter__(self):
self._open_with_retry()
return self
Expand Down
40 changes: 32 additions & 8 deletions sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import logging
import functools
import platform
from typing import Optional, Dict
import time
from typing import Optional, Dict, Tuple
try:
from urlparse import urlparse
except ImportError:
Expand Down Expand Up @@ -63,11 +64,15 @@ def utc_now():
return datetime.datetime.now(tz=TZ_UTC)


# This parse_conn_str is used for mgmt, the other in base_handler for handlers. Should be unified.
def parse_conn_str(conn_str):
endpoint = None
shared_access_key_name = None
shared_access_key = None
entity_path = None
# type: (str) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
endpoint = ''
shared_access_key_name = None # type: Optional[str]
shared_access_key = None # type: Optional[str]
entity_path = ''
shared_access_signature = None # type: Optional[str]
shared_access_signature_expiry = None # type: Optional[int]
for element in conn_str.split(';'):
key, _, value = element.partition('=')
if key.lower() == 'endpoint':
Expand All @@ -78,9 +83,28 @@ def parse_conn_str(conn_str):
shared_access_key = value
elif key.lower() == 'entitypath':
entity_path = value
if not all([endpoint, shared_access_key_name, shared_access_key]):
raise ValueError("Invalid connection string")
return endpoint, shared_access_key_name, shared_access_key, entity_path
elif key.lower() == "sharedaccesssignature":
shared_access_signature = value
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
# type: ignore
shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0])
except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional.
# An arbitrary, absurdly large number, since you can't renew.
shared_access_signature_expiry = int(time.time() * 2)
if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))) \
or all((shared_access_key_name, shared_access_signature)): # this latter clause since we don't accept both
raise ValueError(
"Invalid connection string. Should be in the format: "
"Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>"
"\nWith alternate option of providing SharedAccessSignature instead of SharedAccessKeyName and Key"
)
return (endpoint,
str(shared_access_key_name) if shared_access_key_name else None,
str(shared_access_key) if shared_access_key else None,
entity_path,
str(shared_access_signature) if shared_access_signature else None,
shared_access_signature_expiry)


def build_uri(address, entity):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

import uamqp

from ._base_handler import _parse_conn_str, ServiceBusSharedKeyCredential, BaseHandler
from ._base_handler import (
_parse_conn_str,
ServiceBusSharedKeyCredential,
ServiceBusSASTokenCredential,
BaseHandler)
from ._servicebus_sender import ServiceBusSender
from ._servicebus_receiver import ServiceBusReceiver
from ._servicebus_session_receiver import ServiceBusSessionReceiver
Expand All @@ -33,8 +37,8 @@ class ServiceBusClient(object):
The namespace format is: `<yournamespace>.servicebus.windows.net`.
:param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which
implements a particular interface for getting tokens. It accepts
:class:`ServiceBusSharedKeyCredential<azure.servicebus.ServiceBusSharedKeyCredential>`, or credential objects
generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method.
credential objects generated by the azure-identity library and objects that implement the
`get_token(self, *scopes)` method.
:keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`.
:keyword transport_type: The type of transport protocol that will be used for communicating with
the Service Bus service. Default is `TransportType.Amqp`.
Expand Down Expand Up @@ -153,11 +157,15 @@ def from_connection_string(
:caption: Create a new instance of the ServiceBusClient from connection string.
"""
host, policy, key, entity_in_conn_str = _parse_conn_str(conn_str)
host, policy, key, entity_in_conn_str, token, token_expiry = _parse_conn_str(conn_str)
if token and token_expiry:
credential = ServiceBusSASTokenCredential(token, token_expiry)
elif policy and key:
credential = ServiceBusSharedKeyCredential(policy, key) # type: ignore
return cls(
fully_qualified_namespace=host,
entity_name=entity_in_conn_str or kwargs.pop("entity_name", None),
credential=ServiceBusSharedKeyCredential(policy, key), # type: ignore
credential=credential, # type: ignore
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from uamqp.constants import SenderSettleMode
from uamqp.authentication.common import AMQPAuth

from ._base_handler import BaseHandler, ServiceBusSharedKeyCredential, _convert_connection_string_to_kwargs
from ._base_handler import BaseHandler
from ._common.utils import create_authentication
from ._common.message import PeekedMessage, ReceivedMessage
from ._common.constants import (
Expand Down Expand Up @@ -56,8 +56,8 @@ class ServiceBusReceiver(BaseHandler, ReceiverMixin): # pylint: disable=too-man
The namespace format is: `<yournamespace>.servicebus.windows.net`.
:param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which
implements a particular interface for getting tokens. It accepts
:class:`ServiceBusSharedKeyCredential<azure.servicebus.ServiceBusSharedKeyCredential>`, or credential objects
generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method.
:class: credential objects generated by the azure-identity library and objects that implement the
`get_token(self, *scopes)` method.
:keyword str queue_name: The path of specific Service Bus Queue the client connects to.
:keyword str topic_name: The path of specific Service Bus Topic which contains the Subscription
the client connects to.
Expand Down Expand Up @@ -363,9 +363,8 @@ def from_connection_string(
:caption: Create a new instance of the ServiceBusReceiver from connection string.
"""
constructor_args = _convert_connection_string_to_kwargs(
constructor_args = cls._convert_connection_string_to_kwargs(
conn_str,
ServiceBusSharedKeyCredential,
**kwargs
)
if kwargs.get("queue_name") and kwargs.get("subscription_name"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from uamqp import SendClient, types
from uamqp.authentication.common import AMQPAuth

from ._base_handler import BaseHandler, ServiceBusSharedKeyCredential, _convert_connection_string_to_kwargs
from ._base_handler import BaseHandler
from ._common import mgmt_handlers
from ._common.message import Message, BatchMessage
from .exceptions import (
Expand Down Expand Up @@ -97,8 +97,8 @@ class ServiceBusSender(BaseHandler, SenderMixin):
The namespace format is: `<yournamespace>.servicebus.windows.net`.
:param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which
implements a particular interface for getting tokens. It accepts
:class:`ServiceBusSharedKeyCredential<azure.servicebus.ServiceBusSharedKeyCredential>`, or credential objects
generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method.
:class: credential objects generated by the azure-identity library and objects that implement the
`get_token(self, *scopes)` method.
:keyword str queue_name: The path of specific Service Bus Queue the client connects to.
:keyword str topic_name: The path of specific Service Bus Topic the client connects to.
:keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`.
Expand Down Expand Up @@ -293,9 +293,8 @@ def from_connection_string(
:caption: Create a new instance of the ServiceBusSender from connection string.

"""
constructor_args = _convert_connection_string_to_kwargs(
constructor_args = cls._convert_connection_string_to_kwargs(
conn_str,
ServiceBusSharedKeyCredential,
**kwargs
)
return cls(**constructor_args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class ServiceBusSessionReceiver(ServiceBusReceiver, SessionReceiverMixin):
The namespace format is: `<yournamespace>.servicebus.windows.net`.
:param ~azure.core.credentials.TokenCredential credential: The credential object used for authentication which
implements a particular interface for getting tokens. It accepts
:class:`ServiceBusSharedKeyCredential<azure.servicebus.ServiceBusSharedKeyCredential>`, or credential objects
generated by the azure-identity library and objects that implement the `get_token(self, *scopes)` method.
:class: credential objects generated by the azure-identity library and objects that implement the
`get_token(self, *scopes)` method.
:keyword str queue_name: The path of specific Service Bus Queue the client connects to.
:keyword str topic_name: The path of specific Service Bus Topic which contains the Subscription
the client connects to.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# license information.
# -------------------------------------------------------------------------
from ._async_message import ReceivedMessage
from ._base_handler_async import ServiceBusSharedKeyCredential
from ._servicebus_sender_async import ServiceBusSender
from ._servicebus_receiver_async import ServiceBusReceiver
from ._servicebus_session_receiver_async import ServiceBusSessionReceiver
Expand All @@ -19,6 +18,5 @@
'ServiceBusReceiver',
'ServiceBusSessionReceiver',
'ServiceBusSession',
'ServiceBusSharedKeyCredential',
'AutoLockRenew'
]
Loading