Skip to content

Commit

Permalink
Enable TransportType.AmqpOverWebsocket for ServiceBus (Azure#10012)
Browse files Browse the repository at this point in the history
* Enable TransportType.AmqpOverWebsocket for ServiceBus #4250
* Enables passing TransportType enum directly to constructors with use TransportType.Amqp as default, as well as parsing TransportType out of connection strings.
* Adds sync and async tests for transport_type parsing.

Co-authored-by: Kieran Brantner-Magee <[email protected]>
  • Loading branch information
dennispg and KieranBrantnerMagee authored Mar 31, 2020
1 parent ef48277 commit 9c14796
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 14 deletions.
5 changes: 4 additions & 1 deletion sdk/servicebus/azure-servicebus/azure/servicebus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
AutoLockRenewFailed,
AutoLockRenewTimeout)

from uamqp.constants import TransportType


__all__ = [
'Message',
Expand All @@ -52,4 +54,5 @@
'MessageLockExpired',
'SessionLockExpired',
'AutoLockRenewFailed',
'AutoLockRenewTimeout']
'AutoLockRenewTimeout',
'TransportType']
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import requests
from uamqp import types
from uamqp.constants import TransportType

from azure.servicebus.common import mgmt_handlers, mixins
from azure.servicebus.aio.async_base_handler import BaseHandler
Expand All @@ -36,6 +37,9 @@ class ServiceBusClient(mixins.ServiceBusMixin):
:param str host_base: Optional. Live host base URL. Defaults to Azure URL.
:param str shared_access_key_name: SAS authentication key name.
:param str shared_access_key_value: SAS authentication key value.
:param transport_type: Optional. Underlying transport protocol type (Amqp or AmqpOverWebsocket)
Default value is ~azure.servicebus.TransportType.Amqp
:type transport_type: ~azure.servicebus.TransportType
:param loop: An async event loop.
:param int http_request_timeout: Optional. Timeout for the HTTP request, in seconds.
:param http_request_session: Optional. Session object to use for HTTP requests.
Expand All @@ -53,13 +57,15 @@ class ServiceBusClient(mixins.ServiceBusMixin):

def __init__(self, *, service_namespace=None, host_base=SERVICE_BUS_HOST_BASE,
shared_access_key_name=None, shared_access_key_value=None, loop=None,
transport_type=TransportType.Amqp,
http_request_timeout=DEFAULT_HTTP_TIMEOUT, http_request_session=None, debug=False):

self.loop = loop or get_running_loop()
self.service_namespace = service_namespace
self.host_base = host_base
self.shared_access_key_name = shared_access_key_name
self.shared_access_key_value = shared_access_key_value
self.transport_type = transport_type
self.debug = debug
self.mgmt_client = ServiceBusService(
service_namespace=service_namespace,
Expand All @@ -85,7 +91,7 @@ def from_connection_string(cls, conn_str, *, loop=None, **kwargs):
:caption: Create a ServiceBusClient via a connection string.
"""
address, policy, key, _ = parse_conn_str(conn_str)
address, policy, key, _, transport_type = parse_conn_str(conn_str)
parsed_namespace = urlparse(address)
namespace, _, base = parsed_namespace.hostname.partition('.')
return cls(
Expand All @@ -94,6 +100,7 @@ def from_connection_string(cls, conn_str, *, loop=None, **kwargs):
shared_access_key_value=key,
host_base='.' + base,
loop=loop,
transport_type=transport_type,
**kwargs)

def get_queue(self, queue_name):
Expand Down Expand Up @@ -124,6 +131,7 @@ def get_queue(self, queue_name):
self._get_host(), queue,
shared_access_key_name=self.shared_access_key_name,
shared_access_key_value=self.shared_access_key_value,
transport_type=self.transport_type,
mgmt_client=self.mgmt_client,
loop=self.loop,
debug=self.debug)
Expand All @@ -144,6 +152,7 @@ def list_queues(self):
self._get_host(), queue,
shared_access_key_name=self.shared_access_key_name,
shared_access_key_value=self.shared_access_key_value,
transport_type=self.transport_type,
mgmt_client=self.mgmt_client,
loop=self.loop,
debug=self.debug))
Expand Down Expand Up @@ -177,6 +186,7 @@ def get_topic(self, topic_name):
self._get_host(), topic,
shared_access_key_name=self.shared_access_key_name,
shared_access_key_value=self.shared_access_key_value,
transport_type=self.transport_type,
loop=self.loop,
debug=self.debug)

Expand All @@ -196,6 +206,7 @@ def list_topics(self):
self._get_host(), topic,
shared_access_key_name=self.shared_access_key_name,
shared_access_key_value=self.shared_access_key_value,
transport_type=self.transport_type,
loop=self.loop,
debug=self.debug))
return topic_clients
Expand Down Expand Up @@ -230,6 +241,7 @@ def get_subscription(self, topic_name, subscription_name):
self._get_host(), topic_name, subscription,
shared_access_key_name=self.shared_access_key_name,
shared_access_key_value=self.shared_access_key_value,
transport_type=self.transport_type,
loop=self.loop,
debug=self.debug)

Expand All @@ -254,6 +266,7 @@ def list_subscriptions(self, topic_name):
self._get_host(), topic_name, sub,
shared_access_key_name=self.shared_access_key_name,
shared_access_key_value=self.shared_access_key_value,
transport_type=self.transport_type,
loop=self.loop,
debug=self.debug))
return sub_clients
Expand Down Expand Up @@ -796,11 +809,11 @@ def from_connection_string(cls, conn_str, name, topic=None, **kwargs): # pylint
not included in the connection string.
:type topic: str
"""
address, policy, key, entity = parse_conn_str(conn_str)
address, policy, key, entity, transport_type = parse_conn_str(conn_str)
entity = topic or entity
address = build_uri(address, entity)
address += "/Subscriptions/" + name
return cls(address, name, shared_access_key_name=policy, shared_access_key_value=key, **kwargs)
return cls(address, name, shared_access_key_name=policy, shared_access_key_value=key, transport_type=transport_type, **kwargs)

@classmethod
def from_entity(cls, address, topic, entity, **kwargs): # pylint: disable=arguments-differ
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from urllib.parse import unquote_plus

from uamqp import Source
from uamqp.constants import TransportType

import azure.common
import azure.servicebus
Expand Down Expand Up @@ -288,6 +289,8 @@ def __init__(self, address, name, shared_access_key_name=None,
'key_name': shared_access_key_name,
'shared_access_key': shared_access_key_value}

self.auth_config['transport_type'] = kwargs.get('transport_type') or TransportType.Amqp

self.mgmt_client = kwargs.get('mgmt_client') or ServiceBusService(
service_namespace=namespace,
shared_access_key_name=shared_access_key_name,
Expand All @@ -312,11 +315,11 @@ def from_connection_string(cls, conn_str, name=None, **kwargs):
:param name: The name of the entity, if the 'EntityName' property is
not included in the connection string.
"""
address, policy, key, entity = parse_conn_str(conn_str)
address, policy, key, entity, transport_type = parse_conn_str(conn_str)
entity = name or entity
address = build_uri(address, entity)
name = address.split('/')[-1]
return cls(address, name, shared_access_key_name=policy, shared_access_key_value=key, **kwargs)
return cls(address, name, shared_access_key_name=policy, shared_access_key_value=key, transport_type=transport_type, **kwargs)

def _get_entity(self):
raise NotImplementedError("Must be implemented by child class.")
Expand Down
18 changes: 14 additions & 4 deletions sdk/servicebus/azure-servicebus/azure/servicebus/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from urllib.parse import urlparse
from concurrent.futures import ThreadPoolExecutor

from uamqp.constants import TransportType

from azure.servicebus.common.errors import AutoLockRenewFailed, AutoLockRenewTimeout
from azure.servicebus import __version__ as sdk_version

Expand All @@ -30,14 +32,14 @@ def get_running_loop():
try:
loop = asyncio._get_running_loop() # pylint: disable=protected-access
except AttributeError:
logger.warning('This version of Python is deprecated, please upgrade to >= v3.5.3')
_log.warning('This version of Python is deprecated, please upgrade to >= v3.5.3')
if loop is None:
logger.warning('No running event loop')
_log.warning('No running event loop')
loop = asyncio.get_event_loop()
return loop
except RuntimeError:
# For backwards compatibility, create new event loop
logger.warning('No running event loop')
_log.warning('No running event loop')
return asyncio.get_event_loop()


Expand All @@ -46,6 +48,7 @@ def parse_conn_str(conn_str):
shared_access_key_name = None
shared_access_key = None
entity_path = None
transport_type = TransportType.Amqp
for element in conn_str.split(';'):
key, _, value = element.partition('=')
if key.lower() == 'endpoint':
Expand All @@ -56,9 +59,16 @@ def parse_conn_str(conn_str):
shared_access_key = value
elif key.lower() == 'entitypath':
entity_path = value
elif key.lower() == 'transporttype':
if value.lower() == "amqpoverwebsocket":
transport_type = TransportType.AmqpOverWebsocket
elif value.lower() == "amqp":
transport_type = TransportType.Amqp
else:
raise ValueError("Invalid value for TransportType in connection string")
if not all([endpoint, shared_access_key_name, shared_access_key]):
raise ValueError("Invalid connection string")
return endpoint, shared_access_key_name, shared_access_key, entity_path
return endpoint, shared_access_key_name, shared_access_key, entity_path, transport_type


def build_uri(address, entity):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from urllib.parse import urlparse

from uamqp import types
from uamqp.constants import TransportType

from azure.servicebus.common import mgmt_handlers, mixins
from azure.servicebus.common.constants import (
Expand All @@ -39,6 +40,9 @@ class ServiceBusClient(mixins.ServiceBusMixin):
:param str host_base: Optional. Live host base URL. Defaults to Public Azure.
:param str shared_access_key_name: SAS authentication key name.
:param str shared_access_key_value: SAS authentication key value.
:param transport_type: Optional. Underlying transport protocol type (Amqp or AmqpOverWebsocket)
Default value is ~azure.servicebus.TransportType.Amqp
:type transport_type: ~azure.servicebus.TransportType
:param int http_request_timeout: Optional. Timeout for the HTTP request, in seconds.
Default value is 65 seconds.
:param http_request_session: Optional. Session object to use for HTTP requests.
Expand All @@ -57,12 +61,14 @@ class ServiceBusClient(mixins.ServiceBusMixin):

def __init__(self, service_namespace=None, host_base=SERVICE_BUS_HOST_BASE,
shared_access_key_name=None, shared_access_key_value=None,
transport_type=TransportType.Amqp,
http_request_timeout=DEFAULT_HTTP_TIMEOUT, http_request_session=None, debug=False):

self.service_namespace = service_namespace
self.host_base = host_base
self.shared_access_key_name = shared_access_key_name
self.shared_access_key_value = shared_access_key_value
self.transport_type = transport_type
self.debug = debug
self.mgmt_client = ServiceBusService(
service_namespace=service_namespace,
Expand All @@ -88,13 +94,14 @@ def from_connection_string(cls, conn_str, **kwargs):
:caption: Create a ServiceBusClient via a connection string.
"""
address, policy, key, _ = parse_conn_str(conn_str)
address, policy, key, _, transport_type = parse_conn_str(conn_str)
parsed_namespace = urlparse(address)
namespace, _, base = parsed_namespace.hostname.partition('.')
return cls(
namespace,
shared_access_key_name=policy,
shared_access_key_value=key,
transport_type=transport_type,
host_base='.' + base,
**kwargs)

Expand Down Expand Up @@ -129,6 +136,7 @@ def get_queue(self, queue_name):
self._get_host(), queue,
shared_access_key_name=self.shared_access_key_name,
shared_access_key_value=self.shared_access_key_value,
transport_type=self.transport_type,
mgmt_client=self.mgmt_client,
debug=self.debug)

Expand Down Expand Up @@ -157,6 +165,7 @@ def list_queues(self):
self._get_host(), queue,
shared_access_key_name=self.shared_access_key_name,
shared_access_key_value=self.shared_access_key_value,
transport_type=self.transport_type,
mgmt_client=self.mgmt_client,
debug=self.debug))
return queue_clients
Expand Down Expand Up @@ -189,6 +198,7 @@ def get_topic(self, topic_name):
self._get_host(), topic,
shared_access_key_name=self.shared_access_key_name,
shared_access_key_value=self.shared_access_key_value,
transport_type=self.transport_type,
debug=self.debug)

def list_topics(self):
Expand Down Expand Up @@ -216,6 +226,7 @@ def list_topics(self):
self._get_host(), topic,
shared_access_key_name=self.shared_access_key_name,
shared_access_key_value=self.shared_access_key_value,
transport_type=self.transport_type,
debug=self.debug))
return topic_clients

Expand Down Expand Up @@ -249,6 +260,7 @@ def get_subscription(self, topic_name, subscription_name):
self._get_host(), topic_name, subscription,
shared_access_key_name=self.shared_access_key_name,
shared_access_key_value=self.shared_access_key_value,
transport_type=self.transport_type,
debug=self.debug)

def list_subscriptions(self, topic_name):
Expand Down Expand Up @@ -281,6 +293,7 @@ def list_subscriptions(self, topic_name):
self._get_host(), topic_name, sub,
shared_access_key_name=self.shared_access_key_name,
shared_access_key_value=self.shared_access_key_value,
transport_type=self.transport_type,
debug=self.debug))
return sub_clients

Expand Down Expand Up @@ -796,11 +809,11 @@ def from_connection_string(cls, conn_str, name, topic=None, **kwargs): # pylint
not included in the connection string.
:type topic: str
"""
address, policy, key, entity = parse_conn_str(conn_str)
address, policy, key, entity, transport_type = parse_conn_str(conn_str)
entity = topic or entity
address = build_uri(address, entity)
address += "/Subscriptions/" + name
return cls(address, name, shared_access_key_name=policy, shared_access_key_value=key, **kwargs)
return cls(address, name, shared_access_key_name=policy, shared_access_key_value=key, transport_type=transport_type, **kwargs)

@classmethod
def from_entity(cls, address, topic, entity, **kwargs): # pylint: disable=arguments-differ
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#--------------------------------------------------------------------------

import pytest

from azure.servicebus.aio import (
ServiceBusClient)
from uamqp.constants import TransportType
from devtools_testutils import AzureMgmtTestCase, RandomNameResourceGroupPreparer
from servicebus_preparer import (
ServiceBusNamespacePreparer
)

class ServiceBusClientAsyncTests(AzureMgmtTestCase):

@pytest.mark.liveTest
@pytest.mark.live_test_only
@RandomNameResourceGroupPreparer(name_prefix='servicebustest')
@ServiceBusNamespacePreparer(name_prefix='servicebustest')
def test_servicebusclient_from_conn_str_amqpoverwebsocket_async(self, servicebus_namespace_connection_string, **kwargs):
sb_client = ServiceBusClient.from_connection_string(servicebus_namespace_connection_string)
assert sb_client.transport_type == TransportType.Amqp

websocket_sb_client = ServiceBusClient.from_connection_string(servicebus_namespace_connection_string + ';TransportType=AmqpOverWebsocket')
assert websocket_sb_client.transport_type == TransportType.AmqpOverWebsocket
15 changes: 14 additions & 1 deletion sdk/servicebus/azure-servicebus/tests/test_sb_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ServiceBusAuthorizationError,
ServiceBusResourceNotFound
)
from uamqp.constants import TransportType
from devtools_testutils import AzureMgmtTestCase, RandomNameResourceGroupPreparer
from servicebus_preparer import (
ServiceBusNamespacePreparer,
Expand Down Expand Up @@ -176,4 +177,16 @@ def test_sb_client_incorrect_queue_conn_str(self, servicebus_queue_authorization

client = ServiceBusClient.from_connection_string(servicebus_queue_authorization_rule_connection_string)
with pytest.raises(AzureHttpError):
client.get_queue(wrong_queue.name)
client.get_queue(wrong_queue.name)


@pytest.mark.liveTest
@pytest.mark.live_test_only
@RandomNameResourceGroupPreparer(name_prefix='servicebustest')
@ServiceBusNamespacePreparer(name_prefix='servicebustest')
def test_servicebusclient_from_conn_str_amqpoverwebsocket(self, servicebus_namespace_connection_string, **kwargs):
sb_client = ServiceBusClient.from_connection_string(servicebus_namespace_connection_string)
assert sb_client.transport_type == TransportType.Amqp

websocket_sb_client = ServiceBusClient.from_connection_string(servicebus_namespace_connection_string + ';TransportType=AmqpOverWebsocket')
assert websocket_sb_client.transport_type == TransportType.AmqpOverWebsocket

0 comments on commit 9c14796

Please sign in to comment.