From c546a02247e1ea1f9dc6e7faedabf81cd4678876 Mon Sep 17 00:00:00 2001 From: swathipil Date: Mon, 1 Aug 2022 13:09:51 -0700 Subject: [PATCH 01/20] add shared connection back into conn manager --- .../azure/eventhub/_connection_manager.py | 86 +++++++++++++++++-- 1 file changed, 79 insertions(+), 7 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py index 623a25ece678..0613390f9bda 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py @@ -4,20 +4,24 @@ # -------------------------------------------------------------------------------------------- from typing import TYPE_CHECKING +from threading import Lock +from enum import Enum -from uamqp import Connection +from uamqp import c_uamqp, Connection as uamqp_Connection +from ._constants import TransportType if TYPE_CHECKING: - from uamqp.authentication import JWTTokenAuth - try: from typing_extensions import Protocol except ImportError: Protocol = object # type: ignore + from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth + class ConnectionManager(Protocol): - def get_connection(self, host, auth): - # type: (str, 'JWTTokenAuth') -> Connection + def get_connection( + self, host: str, auth: uamqp_JWTTokenAuth + ) -> uamqp_Connection: pass def close_connection(self): @@ -27,12 +31,77 @@ def reset_connection_if_broken(self): pass +class _ConnectionMode(Enum): + ShareConnection = 1 + SeparateConnection = 2 + + +class _SharedConnectionManager(object): # pylint:disable=too-many-instance-attributes + def __init__(self, **kwargs): + self._lock = Lock() + self._conn: uamqp_Connection = None + + self._lock = Lock() + self._conn = None # type: uamqp_Connection + + self._container_id = kwargs.get("container_id") + self._debug = kwargs.get("debug") + self._error_policy = kwargs.get("error_policy") + self._properties = kwargs.get("properties") + self._encoding = kwargs.get("encoding") or "UTF-8" + self._transport_type = kwargs.get("transport_type") or TransportType.Amqp + self._http_proxy = kwargs.get("http_proxy") + self._max_frame_size = kwargs.get("max_frame_size") + self._channel_max = kwargs.get("channel_max") + self._idle_timeout = kwargs.get("idle_timeout") + self._remote_idle_timeout_empty_frame_send_ratio = kwargs.get( + "remote_idle_timeout_empty_frame_send_ratio" + ) + + def get_connection(self, host, auth): + # type: (str, uamqp_JWTTokenAuth) -> uamqp_Connection + with self._lock: + if self._conn is None: + self._conn = uamqp_Connection( + host, + auth, + container_id=self._container_id, + max_frame_size=self._max_frame_size, + channel_max=self._channel_max, + idle_timeout=self._idle_timeout, + properties=self._properties, + remote_idle_timeout_empty_frame_send_ratio=self._remote_idle_timeout_empty_frame_send_ratio, + error_policy=self._error_policy, + debug=self._debug, + encoding=self._encoding, + ) + return self._conn + + def close_connection(self): + # type: () -> None + with self._lock: + if self._conn: + self._conn.destroy() + self._conn = None + + def reset_connection_if_broken(self): + # type: () -> None + with self._lock: + if self._conn and self._conn._state in ( # pylint:disable=protected-access + c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.DISCARDING, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.END, # pylint:disable=c-extension-no-member + ): + self._conn = None + + class _SeparateConnectionManager(object): def __init__(self, **kwargs): pass - def get_connection(self, host, auth): # pylint:disable=unused-argument, no-self-use - # type: (str, JWTTokenAuth) -> None + def get_connection(self, endpoint): # pylint:disable=unused-argument, no-self-use + # type: (str) -> None return None def close_connection(self): @@ -46,4 +115,7 @@ def reset_connection_if_broken(self): def get_connection_manager(**kwargs): # type: (...) -> 'ConnectionManager' + connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection) + if connection_mode == _ConnectionMode.ShareConnection: + return _SharedConnectionManager(**kwargs) return _SeparateConnectionManager(**kwargs) From a5f0a5502c4b30dadc7bdfdc7f6e108dad4f1826 Mon Sep 17 00:00:00 2001 From: swathipil Date: Mon, 1 Aug 2022 13:10:53 -0700 Subject: [PATCH 02/20] add uamqp switch changes --- .../azure-eventhub/azure/eventhub/__init__.py | 4 +- .../_buffered_producer/_buffered_producer.py | 14 +- .../_buffered_producer_dispatcher.py | 7 +- .../azure/eventhub/_client_base.py | 188 +++--- .../azure-eventhub/azure/eventhub/_common.py | 257 ++++---- .../azure/eventhub/_configuration.py | 14 +- .../azure/eventhub/_constants.py | 31 +- .../azure/eventhub/_consumer.py | 141 ++--- .../azure/eventhub/_consumer_client.py | 10 +- .../azure/eventhub/_producer.py | 186 +++--- .../azure/eventhub/_producer_client.py | 18 +- .../azure/eventhub/_transport/__init__.py | 4 + .../azure/eventhub/_transport/_base.py | 223 +++++++ .../eventhub/_transport/_uamqp_transport.py | 552 ++++++++++++++++++ .../azure-eventhub/azure/eventhub/_utils.py | 91 +-- .../azure-eventhub/azure/eventhub/_version.py | 2 +- .../azure/eventhub/amqp/_amqp_message.py | 245 +++----- .../azure/eventhub/amqp/_amqp_utils.py | 27 + .../azure/eventhub/amqp/_constants.py | 9 - .../azure/eventhub/exceptions.py | 90 +-- 20 files changed, 1399 insertions(+), 714 deletions(-) create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_transport/__init__.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py index c2a457b2726e..6645bf9ea577 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py @@ -2,12 +2,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from uamqp import constants from ._common import EventData, EventDataBatch from ._version import VERSION __version__ = VERSION +from ._constants import TransportType from ._producer_client import EventHubProducerClient from ._consumer_client import EventHubConsumerClient from ._client_base import EventHubSharedKeyCredential @@ -19,8 +19,6 @@ EventHubConnectionStringProperties, ) -TransportType = constants.TransportType - __all__ = [ "EventData", "EventDataBatch", diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py index ff18a87921bd..7baeaf085d3e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import time import queue import logging @@ -14,6 +15,7 @@ from ..exceptions import OperationTimeoutError if TYPE_CHECKING: + from .._transport._base import AmqpTransport from .._producer_client import SendEventTypes _LOGGER = logging.getLogger(__name__) @@ -31,7 +33,8 @@ def __init__( executor: ThreadPoolExecutor, *, max_wait_time: float = 1, - max_buffer_length: int + max_buffer_length: int, + amqp_transport: AmqpTransport ): self._buffered_queue: queue.Queue = queue.Queue() self._max_buffer_len = max_buffer_length @@ -47,11 +50,12 @@ def __init__( self._cur_batch: Optional[EventDataBatch] = None self._max_message_size_on_link = max_message_size_on_link self._check_max_wait_time_future = None + self._amqp_transport = amqp_transport self.partition_id = partition_id def start(self): with self._lock: - self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) self._running = True if self._max_wait_time: self._last_send_time = time.time() @@ -111,11 +115,11 @@ def put_events(self, events, timeout_time=None): self._buffered_queue.put(self._cur_batch) self._buffered_queue.put(events) # create a new batch for incoming events - self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) except ValueError: # add single event exceeds the cur batch size, create new batch self._buffered_queue.put(self._cur_batch) - self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) self._cur_batch.add(events) self._cur_buffered_len += new_events_len @@ -182,7 +186,7 @@ def flush(self, timeout_time=None, raise_error=True): break # after finishing flushing, reset cur batch and put it into the buffer self._last_send_time = time.time() - self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) _LOGGER.info("Partition %r finished flushing.", self.partition_id) def check_max_wait_time_worker(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py index 3c112de18d18..01b300560dc0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import logging from threading import Lock from concurrent.futures import ThreadPoolExecutor @@ -14,6 +15,7 @@ if TYPE_CHECKING: from .._producer_client import SendEventTypes + from .._transport._base import AmqpTransport _LOGGER = logging.getLogger(__name__) @@ -31,7 +33,8 @@ def __init__( *, max_buffer_length: int = 1500, max_wait_time: float = 1, - executor: Optional[Union[ThreadPoolExecutor, int]] = None + executor: Optional[Union[ThreadPoolExecutor, int]] = None, + amqp_transport: AmqpTransport ): self._buffered_producers: Dict[str, BufferedProducer] = {} self._partition_ids: List[str] = partitions @@ -45,6 +48,7 @@ def __init__( self._max_wait_time = max_wait_time self._max_buffer_length = max_buffer_length self._existing_executor = False + self._amqp_transport = amqp_transport if not executor: self._executor = ThreadPoolExecutor() @@ -86,6 +90,7 @@ def enqueue_events( executor=self._executor, max_wait_time=self._max_wait_time, max_buffer_length=self._max_buffer_length, + amqp_transport=self._amqp_transport ) buffered_producer.start() self._buffered_producers[pid] = buffered_producer diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index 98cc19b0bfa7..716425eda6d0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import logging import uuid @@ -11,15 +11,9 @@ import collections from typing import Any, Dict, Tuple, List, Optional, TYPE_CHECKING, cast, Union from datetime import timedelta - -try: - from urlparse import urlparse - from urllib import quote_plus # type: ignore -except ImportError: - from urllib.parse import urlparse, quote_plus - -from uamqp import AMQPClient, Message, authentication, constants, errors, compat, utils +from urllib.parse import urlparse, quote_plus import six + from azure.core.credentials import ( AccessToken, AzureSasCredential, @@ -29,19 +23,33 @@ from azure.core.pipeline.policies import RetryMode -from .exceptions import _handle_exception, ClientClosedError, ConnectError +from uamqp import utils +from ._transport._uamqp_transport import UamqpTransport +from .exceptions import ClientClosedError from ._configuration import Configuration from ._utils import utc_from_timestamp, parse_sas_credential from ._connection_manager import get_connection_manager from ._constants import ( CONTAINER_PREFIX, JWT_TOKEN_SCOPE, - MGMT_OPERATION, - MGMT_PARTITION_OPERATION, + READ_OPERATION, MGMT_STATUS_CODE, MGMT_STATUS_DESC, + MGMT_OPERATION, + MGMT_PARTITION_OPERATION, ) +if TYPE_CHECKING: + from azure.core.credentials import TokenCredential + from uamqp import Message as uamqp_Message + from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth + + CredentialTypes = Union[ + AzureSasCredential, + AzureNamedKeyCredential, + EventHubSharedKeyCredential, + TokenCredential, + ] _LOGGER = logging.getLogger(__name__) _Address = collections.namedtuple("_Address", "hostname path") @@ -186,6 +194,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # type: (str, Any) -> AccessToken if not scopes: raise ValueError("No token scope provided.") + return _generate_sas_token(scopes[0], self.policy, self.key) @@ -274,8 +283,16 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument class ClientBase(object): # pylint:disable=too-many-instance-attributes - def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwargs): - # type: (str, str, CredentialTypes, Any) -> None + def __init__( + self, + fully_qualified_namespace: str, + eventhub_name: str, + credential: CredentialTypes, + **kwargs: Any, + ) -> None: + self._uamqp_transport = True + self._amqp_transport = UamqpTransport + self.eventhub_name = eventhub_name if not eventhub_name: raise ValueError("The eventhub name can not be None or empty.") @@ -285,16 +302,17 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg if isinstance(credential, AzureSasCredential): self._credential = EventhubAzureSasTokenCredential(credential) elif isinstance(credential, AzureNamedKeyCredential): - self._credential = EventhubAzureNamedKeyTokenCredential(credential) # type: ignore + self._credential = EventhubAzureNamedKeyTokenCredential(credential) else: self._credential = credential # type: ignore self._keep_alive = kwargs.get("keep_alive", 30) self._auto_reconnect = kwargs.get("auto_reconnect", True) - self._mgmt_target = "amqps://{}/{}".format( - self._address.hostname, self.eventhub_name + self._auth_uri = f"sb://{self._address.hostname}{self._address.path}" + self._config = Configuration( + uamqp_transport=self._uamqp_transport, + hostname=self._address.hostname, + **kwargs, ) - self._auth_uri = "sb://{}{}".format(self._address.hostname, self._address.path) - self._config = Configuration(**kwargs) self._debug = self._config.network_tracing self._conn_manager = get_connection_manager(**kwargs) self._idle_timeout = kwargs.get("idle_timeout", None) @@ -313,11 +331,10 @@ def _from_connection_string(conn_str, **kwargs): kwargs["credential"] = EventHubSharedKeyCredential(policy, key) return kwargs - def _create_auth(self): - # type: () -> authentication.JWTTokenAuth + def _create_auth(self) -> uamqp_JWTTokenAuth: """ - Create an ~uamqp.authentication.SASTokenAuth instance to authenticate - the session. + Create an ~uamqp.authentication.SASTokenAuth instance + to authenticate the session. """ try: # ignore mypy's warning because token_type is Optional @@ -325,32 +342,19 @@ def _create_auth(self): except AttributeError: token_type = b"jwt" if token_type == b"servicebus.windows.net:sastoken": - auth = authentication.JWTTokenAuth( - self._auth_uri, + return self._amqp_transport.create_token_auth( self._auth_uri, functools.partial(self._credential.get_token, self._auth_uri), token_type=token_type, - timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, - custom_endpoint_hostname=self._config.custom_endpoint_hostname, - port=self._config.connection_port, - verify=self._config.connection_verify, + config=self._config, + update_token=True, ) - auth.update_token() - return auth - return authentication.JWTTokenAuth( - self._auth_uri, + return self._amqp_transport.create_token_auth( self._auth_uri, functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE), token_type=token_type, - timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, - custom_endpoint_hostname=self._config.custom_endpoint_hostname, - port=self._config.connection_port, - verify=self._config.connection_verify, - refresh_window=300, + config=self._config, + update_token=False, ) def _close_connection(self): @@ -385,26 +389,29 @@ def _backoff( ) raise last_exception - def _management_request(self, mgmt_msg, op_type): - # type: (Message, bytes) -> Any + def _management_request( + self, mgmt_msg: uamqp_Message, op_type: bytes + ) -> Any: # pylint:disable=assignment-from-none retried_times = 0 last_exception = None while retried_times <= self._config.max_retries: mgmt_auth = self._create_auth() - mgmt_client = AMQPClient( - self._mgmt_target, auth=mgmt_auth, debug=self._config.network_tracing + mgmt_client = self._amqp_transport.create_mgmt_client( + self._address, mgmt_auth=mgmt_auth, config=self._config ) try: - conn = self._conn_manager.get_connection( # pylint:disable=assignment-from-none - self._address.hostname, mgmt_auth - ) - mgmt_client.open(connection=conn) - mgmt_msg.application_properties["security_token"] = mgmt_auth.token - response = mgmt_client.mgmt_request( + mgmt_client.open() + while not mgmt_client.client_ready(): + time.sleep(0.05) + mgmt_msg.application_properties[ + "security_token" + ] = self._amqp_transport.get_updated_token(mgmt_auth) + response = self._amqp_transport.mgmt_client_request( + mgmt_client, mgmt_msg, - constants.READ_OPERATION, - op_type=op_type, + operation=READ_OPERATION, + operation_type=op_type, status_code_field=MGMT_STATUS_CODE, description_fields=MGMT_STATUS_DESC, ) @@ -417,24 +424,25 @@ def _management_request(self, mgmt_msg, op_type): if status_code < 400: return response if status_code in [401]: - raise errors.AuthenticationException( - "Management authentication failed. Status code: {}, Description: {!r}".format( - status_code, description - ) + raise self._amqp_transport.get_error( + self._amqp_transport.AUTH_EXCEPTION, + f"Management authentication failed. Status code: {status_code}, Description: {description!r}" ) - if status_code in [404]: - raise ConnectError( - "Management connection failed. Status code: {}, Description: {!r}".format( - status_code, description - ) - ) - raise errors.AMQPConnectionError( - "Management request error. Status code: {}, Description: {!r}".format( - status_code, description + if status_code in [ + 404 + ]: + return self._amqp_transport.get_error( + self._amqp_transport.CONNECTION_ERROR, + f"Management connection failed. Status code: {status_code}, Description: {description!r}" ) + return self._amqp_transport.get_error( + self._amqp_transport.AMQP_CONNECTION_ERROR, + f"Management request error. Status code: {status_code}, Description: {description!r}" ) except Exception as exception: # pylint: disable=broad-except - last_exception = _handle_exception(exception, self) + last_exception = self._amqp_transport._handle_exception( + exception, self + ) # pylint: disable=protected-access self._backoff( retried_times=retried_times, last_exception=last_exception ) @@ -453,12 +461,13 @@ def _add_span_request_attributes(self, span): span.add_attribute("message_bus.destination", self._address.path) span.add_attribute("peer.address", self._address.hostname) - def _get_eventhub_properties(self): - # type:() -> Dict[str, Any] - mgmt_msg = Message(application_properties={"name": self.eventhub_name}) + def _get_eventhub_properties(self) -> Dict[str, Any]: + mgmt_msg = self._amqp_transport.MESSAGE( + application_properties={"name": self.eventhub_name} + ) response = self._management_request(mgmt_msg, op_type=MGMT_OPERATION) output = {} - eh_info = response.get_data() # type: Dict[bytes, Any] + eh_info: Dict[bytes, Any] = response.value if eh_info: output["eventhub_name"] = eh_info[b"name"].decode("utf-8") output["created_at"] = utc_from_timestamp( @@ -475,14 +484,14 @@ def _get_partition_ids(self): def _get_partition_properties(self, partition_id): # type:(str) -> Dict[str, Any] - mgmt_msg = Message( + mgmt_msg = self._amqp_transport.MESSAGE( application_properties={ "name": self.eventhub_name, "partition": partition_id, } ) response = self._management_request(mgmt_msg, op_type=MGMT_PARTITION_OPERATION) - partition_info = response.get_data() # type: Dict[bytes, Any] + partition_info = response.value # type: Dict[bytes, Any] output = {} if partition_info: output["eventhub_name"] = partition_info[b"name"].decode("utf-8") @@ -520,9 +529,7 @@ def _create_handler(self, auth): def _check_closed(self): if self.closed: raise ClientClosedError( - "{} has been closed. Please create a new one to handle event data.".format( - self._name - ) + f"{self._name} has been closed. Please create a new one to handle event data." ) def _open(self): @@ -533,17 +540,13 @@ def _open(self): self._handler.close() auth = self._client._create_auth() self._create_handler(auth) - self._handler.open( - connection=self._client._conn_manager.get_connection( - self._client._address.hostname, auth - ) # pylint: disable=protected-access - ) + self._handler.open() while not self._handler.client_ready(): time.sleep(0.05) self._max_message_size_on_link = ( - self._handler.message_handler._link.peer_max_message_size - or constants.MAX_MESSAGE_LENGTH_BYTES - ) # pylint: disable=protected-access + self._amqp_transport.get_remote_max_message_size(self._handler) + or self._amqp_transport.MAX_FRAME_SIZE_BYTES + ) self.running = True def _close_handler(self): @@ -556,9 +559,16 @@ def _close_connection(self): self._client._conn_manager.reset_connection_if_broken() # pylint: disable=protected-access def _handle_exception(self, exception): - if not self.running and isinstance(exception, compat.TimeoutException): - exception = errors.AuthenticationException("Authorization timeout.") - return _handle_exception(exception, self) + if not self.running and isinstance( + exception, self._amqp_transport.TIMEOUT_EXCEPTION + ): + exception = self._amqp_transport.get_error( + self._amqp_transport.AUTH_EXCEPTION, + "Authorization timeout." + ) + return self._amqp_transport._handle_exception( # pylint: disable=protected-access + exception, self + ) def _do_retryable_operation(self, operation, timeout=None, **kwargs): # pylint:disable=protected-access @@ -576,7 +586,7 @@ def _do_retryable_operation(self, operation, timeout=None, **kwargs): return operation( timeout_time=timeout_time, last_exception=last_exception, - **kwargs + **kwargs, ) return operation() except Exception as exception: # pylint:disable=broad-except diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 1ec58f126b52..b302b9c7727e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -2,9 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import json +import datetime import logging import uuid from typing import ( @@ -22,10 +23,7 @@ import six -from uamqp import BatchMessage, Message, constants - from ._utils import ( - set_message_partition_key, trace_message, utc_from_timestamp, transform_outbound_single_message, @@ -35,7 +33,6 @@ PROP_SEQ_NUMBER, PROP_OFFSET, PROP_PARTITION_KEY, - PROP_PARTITION_KEY_AMQP_SYMBOL, PROP_TIMESTAMP, PROP_ABSOLUTE_EXPIRY_TIME, PROP_CONTENT_ENCODING, @@ -57,9 +54,11 @@ AmqpMessageHeader, AmqpMessageProperties, ) +from ._transport._uamqp_transport import UamqpTransport if TYPE_CHECKING: - import datetime + from uamqp import Message as uamqp_Message, BatchMessage as uamqp_BatchMessage + from ._transport._base import AmqpTransport MessageContent = TypedDict("MessageContent", {"content": bytes, "content_type": str}) PrimitiveTypes = Optional[ @@ -127,62 +126,59 @@ def __init__( self._raw_amqp_message = AmqpAnnotatedMessage( # type: ignore data_body=body, annotations={}, application_properties={} ) - self.message = ( - self._raw_amqp_message._message - ) # pylint:disable=protected-access + # amqp message to be reset right before sending + self._message = UamqpTransport.to_outgoing_amqp_message(self._raw_amqp_message) self._raw_amqp_message.header = AmqpMessageHeader() self._raw_amqp_message.properties = AmqpMessageProperties() self.message_id = None self.content_type = None self.correlation_id = None - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: # pylint: disable=bare-except try: body_str = self.body_as_str() except: body_str = "" - event_repr = "body='{}'".format(body_str) + event_repr = f"body='{body_str}'" try: - event_repr += ", properties={}".format(self.properties) + event_repr += f", properties={self.properties}" except: event_repr += ", properties=" try: - event_repr += ", offset={}".format(self.offset) + event_repr += f", offset={self.offset}" except: event_repr += ", offset=" try: - event_repr += ", sequence_number={}".format(self.sequence_number) + event_repr += f", sequence_number={self.sequence_number}" except: event_repr += ", sequence_number=" try: - event_repr += ", partition_key={!r}".format(self.partition_key) + event_repr += f", partition_key={self.partition_key!r}" except: event_repr += ", partition_key=" try: - event_repr += ", enqueued_time={!r}".format(self.enqueued_time) + event_repr += f", enqueued_time={self.enqueued_time!r}" except: event_repr += ", enqueued_time=" - return "EventData({})".format(event_repr) + return f"EventData({event_repr})" - def __str__(self): - # type: () -> str + def __str__(self) -> str: try: body_str = self.body_as_str() except: # pylint: disable=bare-except body_str = "" - event_str = "{{ body: '{}'".format(body_str) + event_str = f"{{ body: '{body_str}'" try: - event_str += ", properties: {}".format(self.properties) + event_str += f", properties: {self.properties}" if self.offset: - event_str += ", offset: {}".format(self.offset) + event_str += f", offset: {self.offset}" if self.sequence_number: - event_str += ", sequence_number: {}".format(self.sequence_number) + event_str += f", sequence_number: {self.sequence_number}" if self.partition_key: - event_str += ", partition_key={!r}".format(self.partition_key) + event_str += f", partition_key={self.partition_key!r}" if self.enqueued_time: - event_str += ", enqueued_time={!r}".format(self.enqueued_time) + event_str += f", enqueued_time={self.enqueued_time!r}" except: # pylint: disable=bare-except pass event_str += " }" @@ -213,8 +209,11 @@ def from_message_content( # pylint: disable=unused-argument return event_data @classmethod - def _from_message(cls, message, raw_amqp_message=None): - # type: (Message, Optional[AmqpAnnotatedMessage]) -> EventData + def _from_message( + cls, + message: uamqp_Message, + raw_amqp_message: Optional[AmqpAnnotatedMessage] = None, + ) -> EventData: # pylint:disable=protected-access """Internal use only. @@ -225,7 +224,7 @@ def _from_message(cls, message, raw_amqp_message=None): :rtype: ~azure.eventhub.EventData """ event_data = cls(body="") - event_data.message = message + event_data._message = message # pylint: disable=protected-access event_data._raw_amqp_message = ( raw_amqp_message @@ -234,39 +233,32 @@ def _from_message(cls, message, raw_amqp_message=None): ) return event_data - def _encode_message(self): - # type: () -> bytes - # pylint: disable=protected-access - return self._raw_amqp_message._message.encode_message() - - def _decode_non_data_body_as_str(self, encoding="UTF-8"): - # type: (str) -> str + def _decode_non_data_body_as_str(self, encoding: str = "UTF-8") -> str: # pylint: disable=protected-access - body = self.raw_amqp_message._message._body + body = self.raw_amqp_message.body if self.body_type == AmqpMessageBodyType.VALUE: - if not body.data: + if not body: return "" - return str(decode_with_recurse(body.data, encoding)) + return str(decode_with_recurse(body, encoding)) - seq_list = [d for seq_section in body.data for d in seq_section] + seq_list = [d for seq_section in body for d in seq_section] return str(decode_with_recurse(seq_list, encoding)) - def _to_outgoing_message(self): - # type: () -> EventData - self.message = ( - self._raw_amqp_message._to_outgoing_amqp_message() # pylint:disable=protected-access - ) - return self + @property + def message(self) -> uamqp_Message: + return self._message + + @message.setter + def message(self, value: uamqp_Message) -> None: + self._message = value @property - def raw_amqp_message(self): - # type: () -> AmqpAnnotatedMessage + def raw_amqp_message(self) -> AmqpAnnotatedMessage: """Advanced usage only. The internal AMQP message payload that is sent or received.""" return self._raw_amqp_message @property - def sequence_number(self): - # type: () -> Optional[int] + def sequence_number(self) -> Optional[int]: """The sequence number of the event. :rtype: int @@ -274,8 +266,7 @@ def sequence_number(self): return self._raw_amqp_message.annotations.get(PROP_SEQ_NUMBER, None) @property - def offset(self): - # type: () -> Optional[str] + def offset(self) -> Optional[str]: """The offset of the event. :rtype: str @@ -286,8 +277,7 @@ def offset(self): return None @property - def enqueued_time(self): - # type: () -> Optional[datetime.datetime] + def enqueued_time(self) -> Optional[datetime.datetime]: """The enqueued timestamp of the event. :rtype: datetime.datetime @@ -298,20 +288,20 @@ def enqueued_time(self): return None @property - def partition_key(self): - # type: () -> Optional[bytes] + def partition_key(self) -> Optional[bytes]: """The partition key of the event. :rtype: bytes """ - try: - return self._raw_amqp_message.annotations[PROP_PARTITION_KEY_AMQP_SYMBOL] - except KeyError: - return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None) + # TODO: Ask Anna. can we remove the try and just do except? Haven't seen a case where symbol is used to get. + # try: + # return self._raw_amqp_message.annotations[types.AMQPSymbol(PROP_PARTITION_KEY)] + # except KeyError: + # return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None) + return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None) @property - def properties(self): - # type: () -> Dict[Union[str, bytes], Any] + def properties(self) -> Dict[Union[str, bytes], Any]: """Application-defined properties on the event. :rtype: dict @@ -319,7 +309,7 @@ def properties(self): return self._raw_amqp_message.application_properties @properties.setter - def properties(self, value): + def properties(self, value: Dict[Union[str, bytes], Any]): # type: (Dict[Union[str, bytes], Any]) -> None """Application-defined properties on the event. @@ -329,8 +319,7 @@ def properties(self, value): self._raw_amqp_message.application_properties = properties @property - def system_properties(self): - # type: () -> Dict[bytes, Any] + def system_properties(self) -> Dict[bytes, Any]: """Metadata set by the Event Hubs Service associated with the event. An EventData could have some or all of the following meta data depending on the source @@ -368,8 +357,7 @@ def system_properties(self): return self._sys_properties @property - def body(self): - # type: () -> PrimitiveTypes + def body(self) -> PrimitiveTypes: """The body of the Message. The format may vary depending on the body type: For :class:`azure.eventhub.amqp.AmqpMessageBodyType.DATA`, the body could be bytes or Iterable[bytes]. @@ -386,16 +374,14 @@ def body(self): raise ValueError("Event content empty.") @property - def body_type(self): - # type: () -> AmqpMessageBodyType + def body_type(self) -> AmqpMessageBodyType: """The body type of the underlying AMQP message. :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType """ return self._raw_amqp_message.body_type - def body_as_str(self, encoding="UTF-8"): - # type: (str) -> str + def body_as_str(self, encoding: str = "UTF-8") -> str: """The content of the event as a string, if the data is of a compatible type. :param encoding: The encoding to use for decoding event data. @@ -414,12 +400,9 @@ def body_as_str(self, encoding="UTF-8"): try: return cast(bytes, data).decode(encoding) except Exception as e: - raise TypeError( - "Message data is not compatible with string type: {}".format(e) - ) + raise TypeError(f"Message data is not compatible with string type: {e}") - def body_as_json(self, encoding="UTF-8"): - # type: (str) -> Dict[str, Any] + def body_as_json(self, encoding: str = "UTF-8") -> Dict[str, Any]: """The content of the event loaded as a JSON object, if the data is compatible. :param encoding: The encoding to use for decoding event data. @@ -430,11 +413,10 @@ def body_as_json(self, encoding="UTF-8"): try: return json.loads(data_str) except Exception as e: - raise TypeError("Event data is not compatible with JSON type: {}".format(e)) + raise TypeError(f"Event data is not compatible with JSON type: {e}") @property - def content_type(self): - # type: () -> Optional[str] + def content_type(self) -> Optional[str]: """The content type descriptor. Optionally describes the payload of the message, with a descriptor following the format of RFC2045, Section 5, for example "application/json". @@ -448,15 +430,13 @@ def content_type(self): return self._raw_amqp_message.properties.content_type @content_type.setter - def content_type(self, value): - # type: (str) -> None + def content_type(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.content_type = value @property - def correlation_id(self): - # type: () -> Optional[str] + def correlation_id(self) -> Optional[str]: """The correlation identifier. Allows an application to specify a context for the message for the purposes of correlation, for example reflecting the MessageId of a message that is being replied to. @@ -470,15 +450,13 @@ def correlation_id(self): return self._raw_amqp_message.properties.correlation_id @correlation_id.setter - def correlation_id(self, value): - # type: (str) -> None + def correlation_id(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.correlation_id = value @property - def message_id(self): - # type: () -> Optional[str] + def message_id(self) -> Optional[str]: """The id to identify the message. The message identifier is an application-defined value that uniquely identifies the message and its payload. The identifier is a free-form string and can reflect a GUID or an identifier derived from the @@ -494,7 +472,7 @@ def message_id(self): return self._raw_amqp_message.properties.message_id @message_id.setter - def message_id(self, value): + def message_id(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.message_id = value @@ -525,8 +503,18 @@ class EventDataBatch(object): Event Hub decided by the service. """ - def __init__(self, max_size_in_bytes=None, partition_id=None, partition_key=None): - # type: (Optional[int], Optional[str], Optional[Union[str, bytes]]) -> None + def __init__( + self, + max_size_in_bytes: Optional[int] = None, + partition_id: Optional[str] = None, + partition_key: Optional[Union[str, bytes]] = None, + **kwargs, + ) -> None: + # TODO: this changes API, check with Anna if valid - + # If possible, move out message creation to right before sending. + # Might take more time to loop through events and add them all to batch in `send` than in `add` here + self._amqp_transport = kwargs.pop("amqp_transport", UamqpTransport) + if partition_key and not isinstance( partition_key, (six.text_type, six.binary_type) @@ -538,35 +526,51 @@ def __init__(self, max_size_in_bytes=None, partition_id=None, partition_key=None "partition_key to only be string type, they might fail to parse the non-string value." ) - self.max_size_in_bytes = max_size_in_bytes or constants.MAX_MESSAGE_LENGTH_BYTES - self.message = BatchMessage(data=[], multi_messages=False, properties=None) + self.max_size_in_bytes = ( + max_size_in_bytes or self._amqp_transport.MAX_FRAME_SIZE_BYTES + ) + self._message = self._amqp_transport.BATCH_MESSAGE(data=[]) self._partition_id = partition_id self._partition_key = partition_key - set_message_partition_key(self.message, self._partition_key) - self._size = self.message.gather()[0].get_message_encoded_size() + self._message = self._amqp_transport.set_message_partition_key( + self._message, self._partition_key + ) + self._size = self._amqp_transport.get_batch_message_encoded_size(self._message) self._count = 0 - self._internal_events: List[Union[EventData, AmqpAnnotatedMessage]] = [] - - def __repr__(self): - # type: () -> str - batch_repr = "max_size_in_bytes={}, partition_id={}, partition_key={!r}, event_count={}".format( - self.max_size_in_bytes, self._partition_id, self._partition_key, self._count + self._internal_events: List[ + Union[EventData, AmqpAnnotatedMessage] + ] = [] + + def __repr__(self) -> str: + batch_repr = ( + f"max_size_in_bytes={self.max_size_in_bytes}, partition_id={self._partition_id}, " + f"partition_key={self._partition_key!r}, event_count={self._count}" ) - return "EventDataBatch({})".format(batch_repr) + return f"EventDataBatch({batch_repr})" - def __len__(self): + def __len__(self) -> int: return self._count @classmethod - def _from_batch(cls, batch_data, partition_key=None): - # type: (Iterable[EventData], Optional[AnyStr]) -> EventDataBatch + def _from_batch( + cls, + batch_data: Iterable[EventData], + amqp_transport: AmqpTransport, + partition_key: Optional[AnyStr] = None, + ) -> EventDataBatch: outgoing_batch_data = [ - transform_outbound_single_message(m, EventData) for m in batch_data + transform_outbound_single_message( + m, EventData, amqp_transport.to_outgoing_amqp_message + ) + for m in batch_data ] - batch_data_instance = cls(partition_key=partition_key) - for data in outgoing_batch_data: - batch_data_instance.add(data) + batch_data_instance = cls( + partition_key=partition_key, amqp_transport=amqp_transport + ) + + for event_data in outgoing_batch_data: + batch_data_instance.add(event_data) return batch_data_instance def _load_events(self, events): @@ -581,16 +585,22 @@ def _load_events(self, events): ) @property - def size_in_bytes(self): - # type: () -> int + def size_in_bytes(self) -> int: """The combined size of the events in the batch, in bytes. :rtype: int """ return self._size - def add(self, event_data): - # type: (Union[EventData, AmqpAnnotatedMessage]) -> None + @property + def message(self) -> uamqp_BatchMessage: + return self._message + + @message.setter + def message(self, value: uamqp_BatchMessage) -> None: + self._message = value + + def add(self, event_data: Union[EventData, AmqpAnnotatedMessage]) -> None: """Try to add an EventData to the batch. The total size of an added event is the sum of its body, properties, etc. @@ -603,7 +613,9 @@ def add(self, event_data): :raise: :class:`ValueError`, when exceeding the size limit. """ - outgoing_event_data = transform_outbound_single_message(event_data, EventData) + outgoing_event_data = transform_outbound_single_message( + event_data, EventData, self._amqp_transport.to_outgoing_amqp_message + ) if self._partition_key: if ( @@ -614,13 +626,15 @@ def add(self, event_data): "The partition key of event_data does not match the partition key of this batch." ) if not outgoing_event_data.partition_key: - set_message_partition_key( - outgoing_event_data.message, self._partition_key + self._amqp_transport.set_message_partition_key( + outgoing_event_data._message, # pylint: disable=protected-access + self._partition_key, ) trace_message(outgoing_event_data) - event_data_size = outgoing_event_data.message.get_message_encoded_size() - + event_data_size = self._amqp_transport.get_message_encoded_size( + outgoing_event_data._message # pylint: disable=protected-access + ) # For a BatchMessage, if the encoded_message_size of event_data is < 256, then the overhead cost to encode that # message into the BatchMessage would be 5 bytes, if >= 256, it would be 8 bytes. size_after_add = ( @@ -631,14 +645,9 @@ def add(self, event_data): if size_after_add > self.max_size_in_bytes: raise ValueError( - "EventDataBatch has reached its size limit: {}".format( - self.max_size_in_bytes - ) + f"EventDataBatch has reached its size limit: {self.max_size_in_bytes}" ) - self._internal_events.append(event_data) - self.message._body_gen.append( # pylint: disable=protected-access - outgoing_event_data - ) + self._amqp_transport.add_batch(self, outgoing_event_data, event_data) self._size = size_after_add self._count += 1 diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py index 652fcb9bbc04..06d3e3a21c7a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py @@ -3,14 +3,10 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- from typing import Optional, Dict, Any +from urllib.parse import urlparse -try: - from urlparse import urlparse -except ImportError: - from urllib.parse import urlparse - -from uamqp.constants import TransportType, DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT from azure.core.pipeline.policies import RetryMode +from ._constants import TransportType, DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT class Configuration(object): # pylint:disable=too-many-instance-attributes @@ -39,10 +35,14 @@ def __init__(self, **kwargs): self.connection_verify = kwargs.get("connection_verify") # type: Optional[str] self.connection_port = DEFAULT_AMQPS_PORT self.custom_endpoint_hostname = None + self.hostname = kwargs.pop("hostname") + uamqp_transport = kwargs.pop("uamqp_transport") if self.http_proxy or self.transport_type == TransportType.AmqpOverWebsocket: self.transport_type = TransportType.AmqpOverWebsocket self.connection_port = DEFAULT_AMQP_WSS_PORT + if not uamqp_transport: + self.hostname += "/$servicebus/websocket" # custom end point if self.custom_endpoint_address: @@ -53,5 +53,7 @@ def __init__(self, **kwargs): endpoint = urlparse(self.custom_endpoint_address) self.transport_type = TransportType.AmqpOverWebsocket self.custom_endpoint_hostname = endpoint.hostname + if not uamqp_transport: + self.custom_endpoint_address += "/$servicebus/websocket" # in case proxy and custom endpoint are both provided, we default port to 443 if it's not provided self.connection_port = endpoint.port or DEFAULT_AMQP_WSS_PORT diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py index 8c21614a3932..de5659411a84 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py @@ -4,13 +4,12 @@ # -------------------------------------------------------------------------------------------- from __future__ import unicode_literals -from uamqp import types +from enum import Enum PROP_SEQ_NUMBER = b"x-opt-sequence-number" PROP_OFFSET = b"x-opt-offset" PROP_PARTITION_KEY = b"x-opt-partition-key" -PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) PROP_TIMESTAMP = b"x-opt-enqueued-time" PROP_LAST_ENQUEUED_SEQUENCE_NUMBER = b"last_enqueued_sequence_number" PROP_LAST_ENQUEUED_OFFSET = b"last_enqueued_offset" @@ -35,6 +34,7 @@ TIMEOUT_SYMBOL = b"com.microsoft:timeout" RECEIVER_RUNTIME_METRIC_SYMBOL = b"com.microsoft:enable-receiver-runtime-metric" +MAX_MESSAGE_LENGTH_BYTES = 1024 * 1024 MAX_USER_AGENT_LENGTH = 512 ALL_PARTITIONS = "all-partitions" CONTAINER_PREFIX = "eventhub.pysdk-" @@ -45,10 +45,33 @@ MGMT_STATUS_DESC = b"status-description" USER_AGENT_PREFIX = "azsdk-python-eventhubs" -NO_RETRY_ERRORS = ( +NO_RETRY_ERRORS = [ b"com.microsoft:argument-out-of-range", b"com.microsoft:entity-disabled", b"com.microsoft:auth-failed", b"com.microsoft:precondition-failed", b"com.microsoft:argument-error", -) +] + +CUSTOM_CONDITION_BACKOFF = { + b"com.microsoft:server-busy": 4, + b"com.microsoft:timeout": 2, + b"com.microsoft:operation-cancelled": 0, + b"com.microsoft:container-close": 4 +} + + +## all below - previously uamqp +class TransportType(Enum): + """Transport type + The underlying transport protocol type: + Amqp: AMQP over the default TCP transport protocol, it uses port 5671. + AmqpOverWebsocket: Amqp over the Web Sockets transport protocol, it uses + port 443. + """ + Amqp = 1 + AmqpOverWebsocket = 2 + +DEFAULT_AMQPS_PORT = 5671 +DEFAULT_AMQP_WSS_PORT = 443 +READ_OPERATION = b"READ" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index a2a2f80e3df4..6015bf0186f2 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -2,19 +2,15 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations +from multiprocessing import Event import time import uuid import logging from collections import deque -from typing import TYPE_CHECKING, Callable, Dict, Optional, Any +from typing import TYPE_CHECKING, Callable, Dict, Optional, Any, Deque, Union -import uamqp -from uamqp import types, errors, utils -from uamqp import ReceiveClient, Source, Message - -from .exceptions import _error_handler from ._common import EventData from ._client_base import ConsumerProducerMixin from ._utils import create_properties, event_position_selector @@ -26,7 +22,9 @@ if TYPE_CHECKING: from typing import Deque - from uamqp.authentication import JWTTokenAuth + from uamqp import ReceiveClient as uamqp_ReceiveClient, Message as uamqp_Message + from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth + from ._consumer_client import EventHubConsumerClient @@ -69,8 +67,7 @@ class EventHubConsumer( It is set to `False` by default. """ - def __init__(self, client, source, **kwargs): - # type: (EventHubConsumerClient, str, Any) -> None + def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs: Any) -> None: event_position = kwargs.get("event_position", None) prefetch = kwargs.get("prefetch", 300) owner_level = kwargs.get("owner_level", None) @@ -86,9 +83,10 @@ def __init__(self, client, source, **kwargs): self.stop = False # used by event processor self.handler_ready = False - self._on_event_received = kwargs[ + self._amqp_transport = kwargs.pop("amqp_transport") + self._on_event_received: Callable[[EventData], None] = kwargs[ "on_event_received" - ] # type: Callable[[EventData], None] + ] self._client = client self._source = source self._offset = event_position @@ -97,110 +95,86 @@ def __init__(self, client, source, **kwargs): self._owner_level = owner_level self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect - self._retry_policy = errors.ErrorPolicy( - max_retries=self._client._config.max_retries, - on_error=_error_handler, # pylint:disable=protected-access - ) + self._retry_policy = self._amqp_transport.create_retry_policy(self._client._config) self._reconnect_backoff = 1 - self._link_properties = {} # type: Dict[types.AMQPType, types.AMQPType] + link_properties: Dict[bytes, int] = {} self._error = None self._timeout = 0 - self._idle_timeout = (idle_timeout * 1000) if idle_timeout else None + self._idle_timeout = (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) if idle_timeout else None partition = self._source.split("/")[-1] self._partition = partition - self._name = "EHConsumer-{}-partition{}".format(uuid.uuid4(), partition) + self._name = f"EHConsumer-{uuid.uuid4()}-partition{partition}" if owner_level is not None: - self._link_properties[types.AMQPSymbol(EPOCH_SYMBOL)] = types.AMQPLong( - int(owner_level) - ) + link_properties[EPOCH_SYMBOL] = int(owner_level) link_property_timeout_ms = ( - self._client._config.receive_timeout - or self._timeout # pylint:disable=protected-access - ) * 1000 - self._link_properties[types.AMQPSymbol(TIMEOUT_SYMBOL)] = types.AMQPLong( - int(link_property_timeout_ms) - ) - self._handler = None # type: Optional[ReceiveClient] + self._client._config.receive_timeout or self._timeout # pylint:disable=protected-access + ) * self._amqp_transport.IDLE_TIMEOUT_FACTOR + link_properties[TIMEOUT_SYMBOL] = int(link_property_timeout_ms) + self._link_properties = self._amqp_transport.create_link_properties(link_properties) + self._handler: Optional[uamqp_ReceiveClient] = None self._track_last_enqueued_event_properties = ( track_last_enqueued_event_properties ) - self._message_buffer = deque() # type: Deque[Message] - self._last_received_event = None # type: Optional[EventData] - self._receive_start_time = None # type: Optional[float] - - def _create_handler(self, auth): - # type: (JWTTokenAuth) -> None - source = Source(self._source) - if self._offset is not None: - source.set_filter( - event_position_selector(self._offset, self._offset_inclusive) - ) - desired_capabilities = None - if self._track_last_enqueued_event_properties: - symbol_array = [types.AMQPSymbol(RECEIVER_RUNTIME_METRIC_SYMBOL)] - desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) - - properties = create_properties( - self._client._config.user_agent # pylint:disable=protected-access + self._message_buffer: Deque[uamqp_Message] = deque() + self._last_received_event: Optional[EventData] = None + self._receive_start_time: Optional[float]= None + + def _create_handler(self, auth: uamqp_JWTTokenAuth) -> None: + source = self._amqp_transport.create_source( + self._source, + self._offset, + event_position_selector(self._offset, self._offset_inclusive) ) - self._handler = ReceiveClient( - source, + desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None + + self._handler = self._amqp_transport.create_receive_client( + config=self._client._config, # pylint:disable=protected-access + source=source, auth=auth, - debug=self._client._config.network_tracing, # pylint:disable=protected-access - prefetch=self._prefetch, + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access + link_credit=self._prefetch, link_properties=self._link_properties, - timeout=self._timeout, idle_timeout=self._idle_timeout, - error_policy=self._retry_policy, + retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, - receive_settle_mode=uamqp.constants.ReceiverSettleMode.ReceiveAndDelete, - auto_complete=False, - properties=properties, + properties=create_properties( + self._client._config.user_agent, amqp_transport=self._amqp_transport # pylint:disable=protected-access + ), desired_capabilities=desired_capabilities, + streaming_receive=True, + message_received_callback=self._message_received, ) - self._handler._streaming_receive = True # pylint:disable=protected-access - self._handler._message_received_callback = ( # pylint:disable=protected-access - self._message_received - ) - - def _open_with_retry(self): - # type: () -> None + def _open_with_retry(self) -> None: self._do_retryable_operation(self._open, operation_need_param=False) - def _message_received(self, message): - # type: (uamqp.Message) -> None + def _message_received(self, message: uamqp_Message) -> None: # pylint:disable=protected-access - self._message_buffer.appendleft(message) + self._message_buffer.append(message) def _next_message_in_buffer(self): # pylint:disable=protected-access - message = self._message_buffer.pop() + message = self._message_buffer.popleft() event_data = EventData._from_message(message) self._last_received_event = event_data return event_data - def _open(self): - # type: () -> bool - """Open the EventHubConsumer/EventHubProducer using the supplied connection.""" + def _open(self) -> bool: + """Open the EventHubConsumer/EventHubProducer using the supplied connection. + """ # pylint: disable=protected-access if not self.running: if self._handler: self._handler.close() auth = self._client._create_auth() self._create_handler(auth) - self._handler.open( - connection=self._client._conn_manager.get_connection( - self._client._address.hostname, auth - ) # pylint: disable=protected-access - ) - self.handler_ready = False + self._handler.open() + while not self._handler.client_ready(): + time.sleep(0.05) + self.handler_ready = True self.running = True - if not self.handler_ready: - if self._handler.client_ready(): # type: ignore - self.handler_ready = True return self.handler_ready def receive(self, batch=False, max_batch_size=300, max_wait_time=None): @@ -211,18 +185,17 @@ def receive(self, batch=False, max_batch_size=300, max_wait_time=None): self._receive_start_time = self._receive_start_time or time.time() deadline = self._receive_start_time + ( max_wait_time or 0 - ) # max_wait_time can be None + ) if len(self._message_buffer) < max_batch_size: while retried_times <= max_retries: try: if self._open(): - self._handler.do_work() # type: ignore + self._handler.do_work(batch=self._prefetch) # type: ignore break except Exception as exception: # pylint: disable=broad-except if ( - isinstance(exception, uamqp.errors.LinkDetach) - and exception.condition # pylint: disable=no-member - == uamqp.constants.ErrorCodes.LinkStolen + isinstance(exception, self._amqp_transport.AMQP_LINK_ERROR) + and exception.condition == self._amqp_transport.LINK_STOLEN_CONDITION # pylint: disable=no-member ): raise self._handle_exception(exception) if not self.running: # exit by close diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py index f45597f06c6b..5b48324bbe05 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py @@ -146,6 +146,7 @@ def __init__( **kwargs # type: Any ): # type: (...) -> None + self._checkpoint_store = kwargs.pop("checkpoint_store", None) self._load_balancing_interval = kwargs.pop("load_balancing_interval", None) if self._load_balancing_interval is None: @@ -210,6 +211,7 @@ def _create_consumer( prefetch=prefetch, idle_timeout=self._idle_timeout, track_last_enqueued_event_properties=track_last_enqueued_event_properties, + amqp_transport=self._amqp_transport, ) return handler @@ -222,9 +224,6 @@ def from_connection_string(cls, conn_str, consumer_group, **kwargs): :param str consumer_group: Receive events from the Event Hub for this consumer group. :keyword str eventhub_name: The path of the specific Event Hub to connect the client to. :keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`. - :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following - keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). - Additionally the following keys may also be present: `'username', 'password'`. :keyword float auth_timeout: The time in seconds to wait for a token to be authorized by the service. The default value is 60 seconds. If set to 0, no timeout will be enforced from the client. :keyword str user_agent: If specified, this will be added in front of the user agent string. @@ -254,6 +253,9 @@ def from_connection_string(cls, conn_str, consumer_group, **kwargs): If the port 5671 is unavailable/blocked in the network environment, `TransportType.AmqpOverWebsocket` could be used instead which uses port 443 for communication. :paramtype transport_type: ~azure.eventhub.TransportType + :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). + Additionally the following keys may also be present: `'username', 'password'`. :keyword checkpoint_store: A manager that stores the partition load-balancing and checkpoint data when receiving events. The checkpoint store will be used in both cases of receiving from all partitions or a single partition. In the latter case load-balancing does not apply. @@ -285,9 +287,9 @@ def from_connection_string(cls, conn_str, consumer_group, **kwargs): :keyword str connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. - :rtype: ~azure.eventhub.EventHubConsumerClient + .. admonition:: Example: .. literalinclude:: ../samples/sync_samples/sample_code_eventhub.py diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index 9bda02cf1d7d..4e990bf78e50 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -2,11 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import uuid import logging -import time import threading from typing import ( Iterable, @@ -18,17 +17,10 @@ TYPE_CHECKING, ) # pylint: disable=unused-import -from uamqp import types, constants, errors -from uamqp import SendClient - -from azure.core.tracing import AbstractSpan - -from .exceptions import _error_handler, OperationTimeoutError from ._common import EventData, EventDataBatch from ._client_base import ConsumerProducerMixin from ._utils import ( create_properties, - set_message_partition_key, trace_message, send_context_manager, transform_outbound_single_message, @@ -39,19 +31,29 @@ _LOGGER = logging.getLogger(__name__) if TYPE_CHECKING: - from uamqp.authentication import JWTTokenAuth # pylint: disable=ungrouped-imports + from azure.core.tracing import AbstractSpan + + from uamqp import constants as uamqp_constants, SendClient as uamqp_SendClient + from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth + from ._transport._base import AmqpTransport from ._producer_client import EventHubProducerClient +_LOGGER = logging.getLogger(__name__) -def _set_partition_key(event_datas, partition_key): - # type: (Iterable[EventData], AnyStr) -> Iterable[EventData] + +def _set_partition_key( + event_datas: Iterable[EventData], + partition_key: AnyStr, + amqp_transport: AmqpTransport, +) -> Iterable[EventData]: for ed in iter(event_datas): - set_message_partition_key(ed.message, partition_key) + amqp_transport.set_message_partition_key(ed._message, partition_key) # pylint: disable=protected-access yield ed -def _set_trace_message(event_datas, parent_span=None): - # type: (Iterable[EventData], Optional[AbstractSpan]) -> Iterable[EventData] +def _set_trace_message( + event_datas: Iterable[EventData], parent_span: Optional["AbstractSpan"] = None +) -> Iterable[EventData]: for ed in iter(event_datas): trace_message(ed, parent_span) yield ed @@ -82,8 +84,11 @@ class EventHubProducer( Default value is `True`. """ - def __init__(self, client, target, **kwargs): - # type: (EventHubProducerClient, str, Any) -> None + def __init__( + self, client: "EventHubProducerClient", target: str, **kwargs: Any + ) -> None: + + self._amqp_transport = kwargs.pop("amqp_transport") partition = kwargs.get("partition", None) send_timeout = kwargs.get("send_timeout", 60) keep_alive = kwargs.get("keep_alive", None) @@ -98,83 +103,59 @@ def __init__(self, client, target, **kwargs): self._target = target self._partition = partition self._timeout = send_timeout - self._idle_timeout = (idle_timeout * 1000) if idle_timeout else None + self._idle_timeout = ( + (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) + if idle_timeout + else None + ) self._error = None self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect - self._retry_policy = errors.ErrorPolicy( - max_retries=self._client._config.max_retries, - on_error=_error_handler, # pylint: disable=protected-access + self._retry_policy = self._amqp_transport.create_retry_policy( + config=self._client._config ) self._reconnect_backoff = 1 - self._name = "EHProducer-{}".format(uuid.uuid4()) - self._unsent_events = [] # type: List[Any] + self._name = f"EHProducer-{uuid.uuid4()}" + self._unsent_events: List[Any] = [] if partition: self._target += "/Partitions/" + partition - self._name += "-partition{}".format(partition) - self._handler = None # type: Optional[SendClient] - self._outcome = None # type: Optional[constants.MessageSendResult] - self._condition = None # type: Optional[Exception] + self._name += f"-partition{partition}" + self._handler: Optional[uamqp_SendClient] = None + self._outcome: Optional[uamqp_constants.MessageSendResult] = None + self._condition: Optional[Exception] = None self._lock = threading.Lock() - self._link_properties = { - types.AMQPSymbol(TIMEOUT_SYMBOL): types.AMQPLong(int(self._timeout * 1000)) - } - - def _create_handler(self, auth): - # type: (JWTTokenAuth) -> None - self._handler = SendClient( - self._target, + self._link_properties = self._amqp_transport.create_link_properties( + {TIMEOUT_SYMBOL: int(self._timeout * 1000)} + ) + + def _create_handler( + self, auth: uamqp_JWTTokenAuth + ) -> None: + self._handler = self._amqp_transport.create_send_client( + config=self._client._config, # pylint:disable=protected-access + target=self._target, auth=auth, - debug=self._client._config.network_tracing, # pylint:disable=protected-access - msg_timeout=self._timeout * 1000, + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access idle_timeout=self._idle_timeout, - error_policy=self._retry_policy, + retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, link_properties=self._link_properties, properties=create_properties( - self._client._config.user_agent # pylint: disable=protected-access + self._client._config.user_agent, # pylint: disable=protected-access + amqp_transport=self._amqp_transport, ), + msg_timeout=self._timeout * 1000, ) - def _open_with_retry(self): - # type: () -> None + def _open_with_retry(self) -> None: return self._do_retryable_operation(self._open, operation_need_param=False) - def _set_msg_timeout(self, timeout_time, last_exception): - # type: (Optional[float], Optional[Exception]) -> None - if not timeout_time: - return - remaining_time = timeout_time - time.time() - if remaining_time <= 0.0: - if last_exception: - error = last_exception - else: - error = OperationTimeoutError("Send operation timed out") - _LOGGER.info("%r send operation timed out. (%r)", self._name, error) - raise error - self._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access - - def _send_event_data(self, timeout_time=None, last_exception=None): - # type: (Optional[float], Optional[Exception]) -> None - if self._unsent_events: - self._open() - self._set_msg_timeout(timeout_time, last_exception) - self._handler.queue_message(*self._unsent_events) # type: ignore - self._handler.wait() # type: ignore - self._unsent_events = self._handler.pending_messages # type: ignore - if self._outcome != constants.MessageSendResult.Ok: - if self._outcome == constants.MessageSendResult.Timeout: - self._condition = OperationTimeoutError("Send operation timed out") - if self._condition: - raise self._condition - - def _send_event_data_with_retry(self, timeout=None): - # type: (Optional[float]) -> None - return self._do_retryable_operation(self._send_event_data, timeout=timeout) - - def _on_outcome(self, outcome, condition): - # type: (constants.MessageSendResult, Optional[Exception]) -> None + def _on_outcome( + self, + outcome: "uamqp_constants.MessageSendResult", + condition: Optional[Exception], + ) -> None: """ Called when the outcome is received for a delivery. @@ -186,19 +167,33 @@ def _on_outcome(self, outcome, condition): self._outcome = outcome self._condition = condition + def _send_event_data( + self, + timeout_time: Optional[float] = None, + last_exception: Optional[Exception] = None, + ) -> None: + if self._unsent_events: + self._amqp_transport.send_messages( + self, timeout_time, last_exception, _LOGGER + ) + + def _send_event_data_with_retry(self, timeout: Optional[float] = None) -> None: + return self._do_retryable_operation(self._send_event_data, timeout=timeout) + def _wrap_eventdata( self, - event_data, # type: Union[EventData, AmqpAnnotatedMessage, EventDataBatch, Iterable[EventData]] - span, # type: Optional[AbstractSpan] - partition_key, # type: Optional[AnyStr] - ): - # type: (...) -> Union[EventData, EventDataBatch] + event_data: Union[EventData, EventDataBatch, Iterable[EventData]], + span: Optional["AbstractSpan"], + partition_key: Optional[AnyStr], + ) -> Union[EventData, EventDataBatch]: if isinstance(event_data, (EventData, AmqpAnnotatedMessage)): outgoing_event_data = transform_outbound_single_message( - event_data, EventData + event_data, EventData, self._amqp_transport.to_outgoing_amqp_message ) if partition_key: - set_message_partition_key(outgoing_event_data.message, partition_key) + self._amqp_transport.set_message_partition_key( + outgoing_event_data._message, partition_key # pylint: disable=protected-access + ) wrapper_event_data = outgoing_event_data trace_message(wrapper_event_data, span) else: @@ -217,24 +212,26 @@ def _wrap_eventdata( ) for ( event - ) in event_data.message._body_gen: # pylint: disable=protected-access + ) in event_data._message.data: # pylint: disable=protected-access trace_message(event, span) wrapper_event_data = event_data # type:ignore else: if partition_key: - event_data = _set_partition_key(event_data, partition_key) + event_data = _set_partition_key( + event_data, partition_key, self._amqp_transport + ) event_data = _set_trace_message(event_data, span) - wrapper_event_data = EventDataBatch._from_batch(event_data, partition_key) # type: ignore # pylint: disable=protected-access - wrapper_event_data.message.on_send_complete = self._on_outcome + wrapper_event_data = EventDataBatch._from_batch( # type: ignore # pylint: disable=protected-access + event_data, self._amqp_transport, partition_key=partition_key + ) return wrapper_event_data def send( self, - event_data, # type: Union[EventData, AmqpAnnotatedMessage, EventDataBatch, Iterable[EventData]] - partition_key=None, # type: Optional[AnyStr] - timeout=None, # type: Optional[float] - ): - # type:(...) -> None + event_data: Union[EventData, EventDataBatch, Iterable[EventData]], + partition_key: Optional[AnyStr] = None, + timeout: Optional[float] = None, + ) -> None: """ Sends an event data and blocks until acknowledgement is received or operation times out. @@ -269,17 +266,14 @@ def send( if not wrapper_event_data: return - self._unsent_events = [wrapper_event_data.message] - + self._unsent_events = [wrapper_event_data._message] # pylint: disable=protected-access if child: self._client._add_span_request_attributes( # pylint: disable=protected-access child ) - self._send_event_data_with_retry(timeout=timeout) - def close(self): - # type:() -> None + def close(self) -> None: """ Close down the handler. If the handler has already closed, this will be a no op. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index 595e7736bb40..3f00c3d501ae 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -19,9 +19,11 @@ ) from typing_extensions import Literal -from uamqp import constants - +from .exceptions import ConnectError, EventHubError +from .amqp import AmqpAnnotatedMessage from ._client_base import ClientBase +from ._producer import EventHubProducer +from ._constants import ALL_PARTITIONS, MAX_MESSAGE_LENGTH_BYTES from ._common import EventDataBatch, EventData from ._constants import ALL_PARTITIONS from ._producer import EventHubProducer @@ -250,6 +252,7 @@ def _buffered_send(self, events, **kwargs): max_wait_time=self._max_wait_time, max_buffer_length=self._max_buffer_length, executor=self._executor, + amqp_transport=self._amqp_transport ) self._buffered_producer_dispatcher.enqueue_events(events, **kwargs) @@ -322,8 +325,10 @@ def _get_max_message_size(self): EventHubProducer, self._producers[ALL_PARTITIONS] )._open_with_retry() self._max_message_size_on_link = ( - self._producers[ALL_PARTITIONS]._handler.message_handler._link.peer_max_message_size # type: ignore - or constants.MAX_MESSAGE_LENGTH_BYTES + self._amqp_transport.get_remote_max_message_size( + self._producers[ALL_PARTITIONS]._handler # type: ignore + ) + or MAX_MESSAGE_LENGTH_BYTES ) def _start_producer(self, partition_id, send_timeout): @@ -364,6 +369,7 @@ def _create_producer(self, partition_id=None, send_timeout=None): partition=partition_id, send_timeout=send_timeout, idle_timeout=self._idle_timeout, + amqp_transport=self._amqp_transport, ) return handler @@ -477,6 +483,9 @@ def from_connection_string( If the port 5671 is unavailable/blocked in the network environment, `TransportType.AmqpOverWebsocket` could be used instead which uses port 443 for communication. :paramtype transport_type: ~azure.eventhub.TransportType + :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). + Additionally the following keys may also be present: `'username', 'password'`. :keyword str custom_endpoint_address: The custom endpoint address to use for establishing a connection to the Event Hubs service, allowing network requests to be routed through any application gateways or other paths needed for the host environment. Default is None. @@ -724,6 +733,7 @@ def create_batch(self, **kwargs): max_size_in_bytes=(max_size_in_bytes or self._max_message_size_on_link), partition_id=partition_id, partition_key=partition_key, + amqp_transport=self._amqp_transport, ) return event_data_batch diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py new file mode 100644 index 000000000000..38266f53027d --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -0,0 +1,223 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from abc import ABC, abstractmethod + +class AmqpTransport(ABC): + """ + Abstract class that defines a set of common methods needed by producer and consumer. + """ + # define constants + BATCH_MESSAGE = None + MAX_FRAME_SIZE_BYTES = None + IDLE_TIMEOUT_FACTOR = None + MESSAGE = None + + # define symbols + PRODUCT_SYMBOL = None + VERSION_SYMBOL = None + FRAMEWORK_SYMBOL = None + PLATFORM_SYMBOL = None + USER_AGENT_SYMBOL = None + PROP_PARTITION_KEY_AMQP_SYMBOL = None + + # errors + AMQP_LINK_ERROR = None + LINK_STOLEN_CONDITION = None + MGMT_AUTH_EXCEPTION = None + CONNECTION_ERROR = None + AMQP_CONNECTION_ERROR = None + + @staticmethod + @abstractmethod + def to_outgoing_amqp_message(annotated_message): + """ + Converts an AmqpAnnotatedMessage into an Amqp Message. + :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert. + :rtype: uamqp.Message or pyamqp.Message + """ + + @staticmethod + @abstractmethod + def get_message_encoded_size(message): + """ + Gets the message encoded size given an underlying Message. + :param uamqp.Message or pyamqp.Message message: Message to get encoded size of. + :rtype: int + """ + + @staticmethod + @abstractmethod + def get_remote_max_message_size(handler): + """ + Returns max peer message size. + :param AMQPClient handler: Client to get remote max message size on link from. + :rtype: int + """ + + @staticmethod + @abstractmethod + def create_retry_policy(config): + """ + Creates the error retry policy. + :param ~azure.eventhub._configuration.Configuration config: Configuration. + """ + + @staticmethod + @abstractmethod + def create_link_properties(link_properties): + """ + Creates and returns the link properties. + :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. + :rtype: dict + """ + + @staticmethod + @abstractmethod + def create_send_client(*, config, **kwargs): + """ + Creates and returns the send client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + + @staticmethod + @abstractmethod + def send_messages(producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + + @staticmethod + @abstractmethod + def set_message_partition_key(message, partition_key, **kwargs): + """Set the partition key as an annotation on a uamqp message. + + :param message: The message to update. + :param str partition_key: The partition key value. + :rtype: None + """ + + @staticmethod + @abstractmethod + def add_batch(batch_message, outgoing_event_data, event_data): + """ + Add EventData to the data body of the BatchMessage. + :param batch_message: BatchMessage to add data to. + :param outgoing_event_data: Transformed EventData for sending. + :param event_data: EventData to add to internal batch events. uamqp use only. + :rtype: None + """ + + @staticmethod + @abstractmethod + def create_source(source, offset, selector): + """ + Creates and returns the Source. + + :param str source: Required. + :param int offset: Required. + :param bytes selector: Required. + """ + + @staticmethod + @abstractmethod + def create_receive_client(*, config, **kwargs): + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword Source source: Required. The source. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. Missing in pyamqp. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + @staticmethod + @abstractmethod + def open_receive_client(*, handler, client, auth): + """ + Opens the receive client. + :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + """ + + @staticmethod + @abstractmethod + def create_token_auth(auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Whether to update token. If not updating token, + then pass 300 to refresh_window. Only used by uamqp. + """ + + @staticmethod + @abstractmethod + def create_mgmt_client(address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + @staticmethod + @abstractmethod + def get_updated_token(mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + + @staticmethod + @abstractmethod + def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + + @staticmethod + @abstractmethod + def get_error(error, message, *, condition=None): + """ + Gets error and passes in error message, and, if applicable, condition. + :param error: The error to raise. + :param str message: Error message. + :param condition: Optional error condition. Will not be used by uamqp. + """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py new file mode 100644 index 000000000000..d928620675c0 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -0,0 +1,552 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import time +import logging +from typing import Optional, Union, Any + +try: + from uamqp import ( + BatchMessage, + constants, + MessageBodyType, + Message, + types, + SendClient, + ReceiveClient, + Source, + utils, + authentication, + AMQPClient, + compat, + errors, + ) + from uamqp.message import ( + MessageHeader, + MessageProperties, + ) + uamqp_installed = True +except ImportError: + uamqp_installed = False + +from ._base import AmqpTransport +from ..amqp._constants import AmqpMessageBodyType +from .._constants import ( + NO_RETRY_ERRORS, + PROP_PARTITION_KEY, +) + +from ..exceptions import ( + ConnectError, + EventDataError, + EventDataSendError, + OperationTimeoutError, + EventHubError, + AuthenticationError, + ConnectionLostError, + EventDataError, + EventDataSendError, +) + +_LOGGER = logging.getLogger(__name__) + +if uamqp_installed: + def _error_handler(error): + """ + Called internally when an event has failed to send so we + can parse the error to determine whether we should attempt + to retry sending the event again. + Returns the action to take according to error type. + + :param error: The error received in the send attempt. + :type error: Exception + :rtype: ~uamqp.errors.ErrorAction + """ + if error.condition == b"com.microsoft:server-busy": + return errors.ErrorAction(retry=True, backoff=4) + if error.condition == b"com.microsoft:timeout": + return errors.ErrorAction(retry=True, backoff=2) + if error.condition == b"com.microsoft:operation-cancelled": + return errors.ErrorAction(retry=True) + if error.condition == b"com.microsoft:container-close": + return errors.ErrorAction(retry=True, backoff=4) + if error.condition in NO_RETRY_ERRORS: + return errors.ErrorAction(retry=False) + return errors.ErrorAction(retry=True) + + + class UamqpTransport(AmqpTransport): + """ + Class which defines uamqp-based methods used by the producer and consumer. + """ + # define constants + BATCH_MESSAGE = BatchMessage + MAX_FRAME_SIZE_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES + IDLE_TIMEOUT_FACTOR = 1000 + MESSAGE = Message + + # define symbols + PRODUCT_SYMBOL = types.AMQPSymbol("product") + VERSION_SYMBOL = types.AMQPSymbol("version") + FRAMEWORK_SYMBOL = types.AMQPSymbol("framework") + PLATFORM_SYMBOL = types.AMQPSymbol("platform") + USER_AGENT_SYMBOL = types.AMQPSymbol("user-agent") + PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) + + # define errors and conditions + AMQP_LINK_ERROR = errors.LinkDetach + LINK_STOLEN_CONDITION = constants.ErrorCodes.LinkStolen + AUTH_EXCEPTION = errors.AuthenticationException + CONNECTION_ERROR = ConnectError + AMQP_CONNECTION_ERROR = errors.AMQPConnectionError + TIMEOUT_EXCEPTION = compat.TimeoutException + + @staticmethod + def to_outgoing_amqp_message(annotated_message): + """ + Converts an AmqpAnnotatedMessage into an Amqp Message. + :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert. + :rtype: uamqp.Message + """ + message_header = None + if annotated_message.header: + message_header = MessageHeader() + message_header.delivery_count = annotated_message.header.delivery_count + message_header.time_to_live = annotated_message.header.time_to_live + message_header.first_acquirer = annotated_message.header.first_acquirer + message_header.durable = annotated_message.header.durable + message_header.priority = annotated_message.header.priority + + message_properties = None + if annotated_message.properties: + message_properties = MessageProperties( + message_id=annotated_message.properties.message_id, + user_id=annotated_message.properties.user_id, + to=annotated_message.properties.to, + subject=annotated_message.properties.subject, + reply_to=annotated_message.properties.reply_to, + correlation_id=annotated_message.properties.correlation_id, + content_type=annotated_message.properties.content_type, + content_encoding=annotated_message.properties.content_encoding, + creation_time=int(annotated_message.properties.creation_time) + if annotated_message.properties.creation_time else None, + absolute_expiry_time=int(annotated_message.properties.absolute_expiry_time) + if annotated_message.properties.absolute_expiry_time else None, + group_id=annotated_message.properties.group_id, + group_sequence=annotated_message.properties.group_sequence, + reply_to_group_id=annotated_message.properties.reply_to_group_id, + encoding=annotated_message._encoding # pylint: disable=protected-access + ) + + # pylint: disable=protected-access + amqp_body_type = annotated_message.body_type + if amqp_body_type == AmqpMessageBodyType.DATA: + amqp_body_type = MessageBodyType.Data + amqp_body = list(annotated_message._data_body) + elif amqp_body_type == AmqpMessageBodyType.SEQUENCE: + amqp_body_type = MessageBodyType.Sequence + amqp_body = list(annotated_message._sequence_body) + else: + amqp_body_type = MessageBodyType.Value + amqp_body = annotated_message._value_body + + return Message( + body=amqp_body, + body_type=amqp_body_type, + header=message_header, + properties=message_properties, + application_properties=annotated_message.application_properties, + annotations=annotated_message.annotations, + delivery_annotations=annotated_message.delivery_annotations, + footer=annotated_message.footer + ) + + @staticmethod + def get_batch_message_encoded_size(message): + """ + Gets the batch message encoded size given an underlying Message. + :param uamqp.BatchMessage message: Message to get encoded size of. + :rtype: int + """ + return message.gather()[0].get_message_encoded_size() + + @staticmethod + def get_message_encoded_size(message): + """ + Gets the message encoded size given an underlying Message. + :param uamqp.Message message: Message to get encoded size of. + :rtype: int + """ + return message.get_message_encoded_size() + + @staticmethod + def get_remote_max_message_size(handler): + """ + Returns max peer message size. + :param AMQPClient handler: Client to get remote max message size on link from. + :rtype: int + """ + return handler.message_handler._link.peer_max_message_size # pylint:disable=protected-access + + @staticmethod + def create_retry_policy(config): + """ + Creates the error retry policy. + :param ~azure.eventhub._configuration.Configuration config: Configuration. + """ + return errors.ErrorPolicy(max_retries=config.max_retries, on_error=_error_handler) + + @staticmethod + def create_link_properties(link_properties): + """ + Creates and returns the link properties. + :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. + :rtype: dict + """ + return {types.AMQPSymbol(symbol): types.AMQPLong(value) for (symbol, value) in link_properties.items()} + + @staticmethod + def create_send_client(*, config, **kwargs): # pylint:disable=unused-argument + """ + Creates and returns the uamqp SendClient. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + target = kwargs.pop("target") + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + + return SendClient( + target, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + **kwargs + ) + + @staticmethod + def _set_msg_timeout(producer, timeout_time, last_exception, logger): + if not timeout_time: + return + remaining_time = timeout_time - time.time() + if remaining_time <= 0.0: + if last_exception: + error = last_exception + else: + error = OperationTimeoutError("Send operation timed out") + logger.info("%r send operation timed out. (%r)", producer._name, error) # pylint: disable=protected-access + raise error + producer._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access + + @staticmethod + def send_messages(producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + # pylint: disable=protected-access + producer._open() + producer._unsent_events[0].on_send_complete = producer._on_outcome + UamqpTransport._set_msg_timeout(producer, timeout_time, last_exception, logger) + producer._handler.queue_message(*producer._unsent_events) # type: ignore + producer._handler.wait() # type: ignore + producer._unsent_events = producer._handler.pending_messages # type: ignore + if producer._outcome != constants.MessageSendResult.Ok: + if producer._outcome == constants.MessageSendResult.Timeout: + producer._condition = OperationTimeoutError("Send operation timed out") + if producer._condition: + raise producer._condition + + @staticmethod + def set_message_partition_key(message, partition_key, **kwargs): # pylint:disable=unused-argument + # type: (Message, Optional[Union[bytes, str]], Any) -> Message + """Set the partition key as an annotation on a uamqp message. + + :param ~uamqp.Message message: The message to update. + :param str partition_key: The partition key value. + :rtype: Message + """ + if partition_key: + annotations = message.annotations + if annotations is None: + annotations = {} + annotations[ + UamqpTransport.PROP_PARTITION_KEY_AMQP_SYMBOL # TODO: see if setting non-amqp symbol is valid + ] = partition_key + header = MessageHeader() + header.durable = True + message.annotations = annotations + message.header = header + return message + + @staticmethod + def add_batch(batch_message, outgoing_event_data, event_data): + """ + Add EventData to the data body of the BatchMessage. + :param batch_message: BatchMessage to add data to. + :param outgoing_event_data: Transformed EventData for sending. + :param event_data: EventData to add to internal batch events. uamqp use only. + :rtype: None + """ + # pylint: disable=protected-access + batch_message._internal_events.append(event_data) + batch_message._message._body_gen.append( + outgoing_event_data._message + ) + + @staticmethod + def create_source(source, offset, selector): + """ + Creates and returns the Source. + + :param str source: Required. + :param int offset: Required. + :param bytes selector: Required. + """ + source = Source(source) + if offset is not None: + source.set_filter(selector) + return source + + @staticmethod + def create_receive_client(*, config, **kwargs): + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str source: Required. The source. + :keyword str offset: Required. + :keyword str offset_inclusive: Required. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + source = kwargs.pop("source") + symbol_array = kwargs.pop("desired_capabilities") + desired_capabilities = None + if symbol_array: + symbol_array = [types.AMQPSymbol(symbol) for symbol in symbol_array] + desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + link_credit = kwargs.pop("link_credit") + streaming_receive = kwargs.pop("streaming_receive") + message_received_callback = kwargs.pop("message_received_callback") + + client = ReceiveClient( + source, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + desired_capabilities=desired_capabilities, + prefetch=link_credit, + receive_settle_mode=constants.ReceiverSettleMode.ReceiveAndDelete, + auto_complete=False, + **kwargs + ) + # pylint:disable=protected-access + client._streaming_receive = streaming_receive + client._message_received_callback = (message_received_callback) + return client + + @staticmethod + def open_receive_client(*, handler, client, auth): + """ + Opens the receive client and returns ready status. + :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + :param auth: Auth. + :rtype: bool + """ + # pylint:disable=protected-access + handler.open(connection=client._conn_manager.get_connection( + client._address.hostname, auth + )) + + @staticmethod + def create_token_auth(auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Required. Whether to update token. If not updating token, + then pass 300 to refresh_window. + """ + update_token = kwargs.pop("update_token") + refresh_window = 300 + if update_token: + refresh_window = 0 + + token_auth = authentication.JWTTokenAuth( + auth_uri, + auth_uri, + get_token, + token_type=token_type, + timeout=config.auth_timeout, + http_proxy=config.http_proxy, + transport_type=config.transport_type, + custom_endpoint_hostname=config.custom_endpoint_hostname, + port=config.connection_port, + verify=config.connection_verify, + refresh_window=refresh_window + ) + if update_token: + token_auth.update_token() + return token_auth + + @staticmethod + def create_mgmt_client(address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + mgmt_target = f"amqps://{address.hostname}{address.path}" + return AMQPClient( + mgmt_target, + auth=mgmt_auth, + debug=config.network_tracing + ) + + @staticmethod + def get_updated_token(mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + return mgmt_auth.token + + @staticmethod + def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + operation_type = kwargs.pop("operation_type") + operation = kwargs.pop("operation") + return mgmt_client.mgmt_request( + mgmt_msg, + operation, + op_type=operation_type, + **kwargs + ) + + @staticmethod + def get_error(error, message, *, condition=None): # pylint: disable=unused-argument + """ + Gets error and passes in error message, and, if applicable, condition. + :param error: The error to raise. + :param str message: Error message. + :param condition: Optional error condition. Will not be used by uamqp. + """ + return error(message) + + @staticmethod + def _create_eventhub_exception(exception): + if isinstance(exception, errors.AuthenticationException): + error = AuthenticationError(str(exception), exception) + elif isinstance(exception, errors.VendorLinkDetach): + error = ConnectError(str(exception), exception) + elif isinstance(exception, errors.LinkDetach): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.ConnectionClose): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.MessageHandlerError): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.AMQPConnectionError): + error_type = ( + AuthenticationError + if str(exception).startswith("Unable to open authentication session") + else ConnectError + ) + error = error_type(str(exception), exception) + elif isinstance(exception, compat.TimeoutException): + error = ConnectionLostError(str(exception), exception) + else: + error = EventHubError(str(exception), exception) + return error + + @staticmethod + def _handle_exception( + exception, closable + ): # pylint:disable=too-many-branches, too-many-statements + try: # closable is a producer/consumer object + name = closable._name # pylint: disable=protected-access + except AttributeError: # closable is an client object + name = closable._container_id # pylint: disable=protected-access + if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise + _LOGGER.info("%r stops due to keyboard interrupt", name) + closable._close_connection() # pylint:disable=protected-access + raise exception + elif isinstance(exception, EventHubError): + closable._close_handler() # pylint:disable=protected-access + raise exception + elif isinstance( + exception, + ( + errors.MessageAccepted, + errors.MessageAlreadySettled, + errors.MessageModified, + errors.MessageRejected, + errors.MessageReleased, + errors.MessageContentTooLarge, + ), + ): + _LOGGER.info("%r Event data error (%r)", name, exception) + error = EventDataError(str(exception), exception) + raise error + elif isinstance(exception, errors.MessageException): + _LOGGER.info("%r Event data send error (%r)", name, exception) + error = EventDataSendError(str(exception), exception) + raise error + else: + if isinstance(exception, errors.AuthenticationException): + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + elif isinstance(exception, errors.LinkDetach): + if hasattr(closable, "_close_handler"): + closable._close_handler() # pylint:disable=protected-access + elif isinstance(exception, errors.ConnectionClose): + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + elif isinstance(exception, errors.MessageHandlerError): + if hasattr(closable, "_close_handler"): + closable._close_handler() # pylint:disable=protected-access + else: # errors.AMQPConnectionError, compat.TimeoutException + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + return UamqpTransport._create_eventhub_exception(exception) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index 744ebfaf3df0..410bd2ff5536 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations from contextlib import contextmanager import sys @@ -20,20 +20,17 @@ Iterable, Tuple, Mapping, + Callable ) import six -from uamqp import types -from uamqp.message import MessageHeader - from azure.core.settings import settings from azure.core.tracing import SpanKind, Link from .amqp import AmqpAnnotatedMessage, AmqpMessageHeader from ._version import VERSION from ._constants import ( - PROP_PARTITION_KEY_AMQP_SYMBOL, MAX_USER_AGENT_LENGTH, USER_AGENT_PREFIX, PROP_LAST_ENQUEUED_SEQUENCE_NUMBER, @@ -41,11 +38,19 @@ PROP_RUNTIME_INFO_RETRIEVAL_TIME_UTC, PROP_LAST_ENQUEUED_OFFSET, PROP_TIMESTAMP, + PROP_PARTITION_KEY ) +# TODO: remove after fixing up async +from uamqp import types +from uamqp.message import MessageHeader +PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) + + if TYPE_CHECKING: # pylint: disable=ungrouped-imports - from uamqp import Message + from ._transport._base import AmqpTransport + from uamqp import types as uamqp_types from azure.core.tracing import AbstractSpan from azure.core.credentials import AzureSasCredential from ._common import EventData @@ -87,8 +92,9 @@ def utc_from_timestamp(timestamp): return datetime.datetime.fromtimestamp(timestamp, tz=TZ_UTC) -def create_properties(user_agent=None): - # type: (Optional[str]) -> Dict[types.AMQPSymbol, str] +def create_properties( + user_agent: Optional[str] = None, *, amqp_transport: AmqpTransport +) -> Dict[uamqp_types.AMQPSymbol, str]: """ Format the properties with which to instantiate the connection. This acts like a user agent over HTTP. @@ -96,32 +102,37 @@ def create_properties(user_agent=None): :rtype: dict """ properties = {} - properties[types.AMQPSymbol("product")] = USER_AGENT_PREFIX - properties[types.AMQPSymbol("version")] = VERSION - framework = "Python/{}.{}.{}".format( - sys.version_info[0], sys.version_info[1], sys.version_info[2] - ) - properties[types.AMQPSymbol("framework")] = framework + properties[amqp_transport.PRODUCT_SYMBOL] = USER_AGENT_PREFIX + properties[amqp_transport.VERSION_SYMBOL] = VERSION + framework = f"Python/{sys.version_info[0]}.{sys.version_info[1]}.{sys.version_info[2]}" + properties[amqp_transport.FRAMEWORK_SYMBOL] = framework platform_str = platform.platform() - properties[types.AMQPSymbol("platform")] = platform_str + properties[amqp_transport.PLATFORM_SYMBOL] = platform_str - final_user_agent = "{}/{} {} ({})".format( - USER_AGENT_PREFIX, VERSION, framework, platform_str - ) + final_user_agent = f"{USER_AGENT_PREFIX}/{VERSION} {framework} ({platform_str})" if user_agent: - final_user_agent = "{} {}".format(user_agent, final_user_agent) + final_user_agent = f"{user_agent} {final_user_agent}" if len(final_user_agent) > MAX_USER_AGENT_LENGTH: raise ValueError( - "The user-agent string cannot be more than {} in length." - "Current user_agent string is: {} with length: {}".format( - MAX_USER_AGENT_LENGTH, final_user_agent, len(final_user_agent) - ) + f"The user-agent string cannot be more than {MAX_USER_AGENT_LENGTH} in length." + f"Current user_agent string is: {final_user_agent} with length: {len(final_user_agent)}" ) - properties[types.AMQPSymbol("user-agent")] = final_user_agent + properties[amqp_transport.USER_AGENT_SYMBOL] = final_user_agent return properties +@contextmanager +def send_context_manager(): + span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] + + if span_impl_type is not None: + with span_impl_type(name="Azure.EventHubs.send", kind=SpanKind.CLIENT) as child: + yield child + else: + yield None + +# TODO: delete after async unit tests have been refactored def set_event_partition_key(event, partition_key): # type: (Union[AmqpAnnotatedMessage, EventData], Optional[Union[bytes, str]]) -> None if not partition_key: @@ -236,15 +247,13 @@ def event_position_selector(value, inclusive=False): value.microsecond / 1000 ) return ( - "amqp.annotation.x-opt-enqueued-time {} '{}'".format( - operator, int(timestamp) - ) + f"amqp.annotation.x-opt-enqueued-time {operator} '{int(timestamp)}'" ).encode("utf-8") elif isinstance(value, six.integer_types): return ( - "amqp.annotation.x-opt-sequence-number {} '{}'".format(operator, value) + f"amqp.annotation.x-opt-sequence-number {operator} '{value}'" ).encode("utf-8") - return ("amqp.annotation.x-opt-offset {} '{}'".format(operator, value)).encode( + return (f"amqp.annotation.x-opt-offset {operator} '{value}'").encode( "utf-8" ) @@ -259,23 +268,23 @@ def get_last_enqueued_event_properties(event_data): if event_data._last_enqueued_event_properties: return event_data._last_enqueued_event_properties - if event_data.message.delivery_annotations: - sequence_number = event_data.message.delivery_annotations.get( + if event_data._message.delivery_annotations: + sequence_number = event_data._message.delivery_annotations.get( PROP_LAST_ENQUEUED_SEQUENCE_NUMBER, None ) - enqueued_time_stamp = event_data.message.delivery_annotations.get( + enqueued_time_stamp = event_data._message.delivery_annotations.get( PROP_LAST_ENQUEUED_TIME_UTC, None ) if enqueued_time_stamp: enqueued_time_stamp = utc_from_timestamp(float(enqueued_time_stamp) / 1000) - retrieval_time_stamp = event_data.message.delivery_annotations.get( + retrieval_time_stamp = event_data._message.delivery_annotations.get( PROP_RUNTIME_INFO_RETRIEVAL_TIME_UTC, None ) if retrieval_time_stamp: retrieval_time_stamp = utc_from_timestamp( float(retrieval_time_stamp) / 1000 ) - offset_bytes = event_data.message.delivery_annotations.get( + offset_bytes = event_data._message.delivery_annotations.get( PROP_LAST_ENQUEUED_OFFSET, None ) offset = offset_bytes.decode("UTF-8") if offset_bytes else None @@ -301,8 +310,8 @@ def parse_sas_credential(credential): return (sas, expiry) -def transform_outbound_single_message(message, message_type): - # type: (Union[AmqpAnnotatedMessage, EventData], Type[EventData]) -> EventData +def transform_outbound_single_message(message, message_type, to_outgoing_amqp_message): + # type: (Union[AmqpAnnotatedMessage, EventData], Type[EventData], Callable) -> EventData """ This method serves multiple goals: 1. update the internal message to reflect any updates to settable properties on EventData @@ -314,14 +323,16 @@ def transform_outbound_single_message(message, message_type): :rtype: EventData """ try: - # EventData # pylint: disable=protected-access - return message._to_outgoing_message() # type: ignore + # EventData + message._message = to_outgoing_amqp_message(message.raw_amqp_message) + return message # type: ignore except AttributeError: - # AmqpAnnotatedMessage # pylint: disable=protected-access + # AmqpAnnotatedMessage + amqp_message = to_outgoing_amqp_message(message) return message_type._from_message( - message=message._to_outgoing_amqp_message(), raw_amqp_message=message # type: ignore + message=amqp_message, raw_amqp_message=message # type: ignore ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py index f517c2385bd5..035c8aaa38ee 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py @@ -3,4 +3,4 @@ # Licensed under the MIT License. # ------------------------------------ -VERSION = "5.10.0" +VERSION = "5.10.1" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index 28a3c9e79fa4..f0846f9b49fd 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -4,11 +4,11 @@ # license information. # ------------------------------------------------------------------------- -from typing import Optional, Any, cast, Mapping, Dict +from __future__ import annotations +from typing import Optional, Any, cast, Mapping, Dict, Union -import uamqp - -from ._constants import AMQP_MESSAGE_BODY_TYPE_MAP, AmqpMessageBodyType +from ._amqp_utils import normalized_data_body, normalized_sequence_body +from ._constants import AmqpMessageBodyType from .._mixin import DictMixin @@ -19,11 +19,9 @@ class AmqpAnnotatedMessage(object): access to low-level AMQP message sections. There should be one and only one of either data_body, sequence_body or value_body being set as the body of the AmqpAnnotatedMessage; if more than one body is set, `ValueError` will be raised. - Please refer to the AMQP spec: http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#section-message-format for more information on the message format. - :keyword data_body: The body consists of one or more data sections and each section contains opaque binary data. :paramtype data_body: Union[str, bytes, List[Union[str, bytes]]] :keyword sequence_body: The body consists of one or more sequence sections and @@ -47,12 +45,15 @@ class AmqpAnnotatedMessage(object): def __init__(self, **kwargs): # type: (Any) -> None - self._message = kwargs.pop("message", None) self._encoding = kwargs.pop("encoding", "UTF-8") + self._data_body = None + self._sequence_body = None + self._value_body = None # internal usage only for Event Hub received message - if self._message: - self._from_amqp_message(self._message) + message = kwargs.pop("message", None) + if message: + self._from_amqp_message(message) return # manually constructed AMQPAnnotatedMessage @@ -69,21 +70,17 @@ def __init__(self, **kwargs): "or value_body being set as the body of the AmqpAnnotatedMessage." ) - self._body = None self._body_type = None if "data_body" in kwargs: - self._body = kwargs.get("data_body") - self._body_type = uamqp.MessageBodyType.Data + self._data_body = normalized_data_body(kwargs.get("data_body")) + self._body_type = AmqpMessageBodyType.DATA elif "sequence_body" in kwargs: - self._body = kwargs.get("sequence_body") - self._body_type = uamqp.MessageBodyType.Sequence + self._sequence_body = normalized_sequence_body(kwargs.get("sequence_body")) + self._body_type = AmqpMessageBodyType.SEQUENCE elif "value_body" in kwargs: - self._body = kwargs.get("value_body") - self._body_type = uamqp.MessageBodyType.Value + self._value_body = kwargs.get("value_body") + self._body_type = AmqpMessageBodyType.VALUE - self._message = uamqp.message.Message( - body=self._body, body_type=self._body_type - ) header_dict = cast(Mapping, kwargs.get("header")) self._header = AmqpMessageHeader(**header_dict) if "header" in kwargs else None self._footer = kwargs.get("footer") @@ -95,11 +92,16 @@ def __init__(self, **kwargs): self._annotations = kwargs.get("annotations") self._delivery_annotations = kwargs.get("delivery_annotations") - def __str__(self): - return str(self._message) + def __str__(self) -> str: + if self._body_type == AmqpMessageBodyType.DATA: + return "".join(d.decode(self._encoding) for d in self._data_body) + elif self._body_type == AmqpMessageBodyType.SEQUENCE: + return str(self._sequence_body) + elif self._body_type == AmqpMessageBodyType.VALUE: + return str(self._value_body) + return "" - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: # pylint: disable=bare-except message_repr = "body={}".format(str(self)) message_repr += ", body_type={}".format(self.body_type) @@ -134,143 +136,80 @@ def __repr__(self): return "AmqpAnnotatedMessage({})".format(message_repr)[:1024] def _from_amqp_message(self, message): - # populate the properties from an uamqp message - self._properties = ( - AmqpMessageProperties( - message_id=message.properties.message_id, - user_id=message.properties.user_id, - to=message.properties.to, - subject=message.properties.subject, - reply_to=message.properties.reply_to, - correlation_id=message.properties.correlation_id, - content_type=message.properties.content_type, - content_encoding=message.properties.content_encoding, - absolute_expiry_time=message.properties.absolute_expiry_time, - creation_time=message.properties.creation_time, - group_id=message.properties.group_id, - group_sequence=message.properties.group_sequence, - reply_to_group_id=message.properties.reply_to_group_id, - ) - if message.properties - else None - ) - self._header = ( - AmqpMessageHeader( - delivery_count=message.header.delivery_count, - time_to_live=message.header.time_to_live, - first_acquirer=message.header.first_acquirer, - durable=message.header.durable, - priority=message.header.priority, - ) - if message.header - else None - ) - self._footer = message.footer - self._annotations = message.annotations - self._delivery_annotations = message.delivery_annotations - self._application_properties = message.application_properties - - def _to_outgoing_amqp_message(self): - message_header = None - if self.header: - message_header = uamqp.message.MessageHeader() - message_header.delivery_count = self.header.delivery_count - message_header.time_to_live = self.header.time_to_live - message_header.first_acquirer = self.header.first_acquirer - message_header.durable = self.header.durable - message_header.priority = self.header.priority - - message_properties = None - if self.properties: - message_properties = uamqp.message.MessageProperties( - message_id=self.properties.message_id, - user_id=self.properties.user_id, - to=self.properties.to, - subject=self.properties.subject, - reply_to=self.properties.reply_to, - correlation_id=self.properties.correlation_id, - content_type=self.properties.content_type, - content_encoding=self.properties.content_encoding, - creation_time=int(self.properties.creation_time) - if self.properties.creation_time - else None, - absolute_expiry_time=int(self.properties.absolute_expiry_time) - if self.properties.absolute_expiry_time - else None, - group_id=self.properties.group_id, - group_sequence=self.properties.group_sequence, - reply_to_group_id=self.properties.reply_to_group_id, - encoding=self._encoding, - ) - - amqp_body = self._message._body # pylint: disable=protected-access - if isinstance(amqp_body, uamqp.message.DataBody): - amqp_body_type = uamqp.MessageBodyType.Data - amqp_body = list(amqp_body.data) - elif isinstance(amqp_body, uamqp.message.SequenceBody): - amqp_body_type = uamqp.MessageBodyType.Sequence - amqp_body = list(amqp_body.data) + self._properties = AmqpMessageProperties( + message_id=message.properties.message_id, + user_id=message.properties.user_id, + to=message.properties.to, + subject=message.properties.subject, + reply_to=message.properties.reply_to, + correlation_id=message.properties.correlation_id, + content_type=message.properties.content_type, + content_encoding=message.properties.content_encoding, + absolute_expiry_time=message.properties.absolute_expiry_time, + creation_time=message.properties.creation_time, + group_id=message.properties.group_id, + group_sequence=message.properties.group_sequence, + reply_to_group_id=message.properties.reply_to_group_id, + ) if message.properties else None + self._header = AmqpMessageHeader( + delivery_count=message.header.delivery_count, + time_to_live=message.header.ttl, + first_acquirer=message.header.first_acquirer, + durable=message.header.durable, + priority=message.header.priority + ) if message.header else None + self._footer = message.footer if message.footer else {} + self._annotations = message.message_annotations if message.message_annotations else {} + self._delivery_annotations = message.delivery_annotations if message.delivery_annotations else {} + self._application_properties = message.application_properties if message.application_properties else {} + if message.data: + self._data_body = list(message.data) + self._body_type = AmqpMessageBodyType.DATA + elif message.sequence: + self._sequence_body = list(message.sequence) + self._body_type = AmqpMessageBodyType.SEQUENCE else: - # amqp_body is type of uamqp.message.ValueBody - amqp_body_type = uamqp.MessageBodyType.Value - amqp_body = amqp_body.data - - return uamqp.message.Message( - body=amqp_body, - body_type=amqp_body_type, - header=message_header, - properties=message_properties, - application_properties=self.application_properties, - annotations=self.annotations, - delivery_annotations=self.delivery_annotations, - footer=self.footer, - ) + self._value_body = message.value + self._body_type = AmqpMessageBodyType.VALUE @property - def body(self): + def body(self) -> Any: # type: () -> Any """The body of the Message. The format may vary depending on the body type: - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.DATA`, - the body could be bytes or Iterable[bytes]. - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.SEQUENCE`, - the body could be List or Iterable[List]. - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.VALUE`, - the body could be any type. - + For ~azure.eventhub.AmqpMessageBodyType.DATA, the body could be bytes or Iterable[bytes] + For ~azure.eventhub.AmqpMessageBodyType.SEQUENCE, the body could be List or Iterable[List] + For ~azure.eventhub.AmqpMessageBodyType.VALUE, the body could be any type. :rtype: Any """ - return self._message.get_data() + if self._body_type == AmqpMessageBodyType.DATA: # pylint:disable=no-else-return + return (i for i in self._data_body) + elif self._body_type == AmqpMessageBodyType.SEQUENCE: + return (i for i in self._sequence_body) + elif self._body_type == AmqpMessageBodyType.VALUE: + return self._value_body + return None @property - def body_type(self): - # type: () -> AmqpMessageBodyType + def body_type(self) -> AmqpMessageBodyType: """The body type of the underlying AMQP message. - - :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType + rtype: ~azure.eventhub.amqp.AmqpMessageBodyType """ - return AMQP_MESSAGE_BODY_TYPE_MAP.get( - self._message._body.type, # pylint: disable=protected-access - AmqpMessageBodyType.VALUE, - ) + return self._body_type @property - def properties(self): - # type: () -> Optional[AmqpMessageProperties] + def properties(self) -> Optional[AmqpMessageProperties]: """ Properties to add to the message. - :rtype: Optional[~azure.eventhub.amqp.AmqpMessageProperties] """ return self._properties @properties.setter - def properties(self, value): - # type: (AmqpMessageProperties) -> None + def properties(self, value: AmqpMessageProperties) -> None: self._properties = value @property - def application_properties(self): - # type: () -> Optional[Dict] + def application_properties(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Service specific application properties. @@ -279,13 +218,11 @@ def application_properties(self): return self._application_properties @application_properties.setter - def application_properties(self, value): - # type: (Dict) -> None + def application_properties(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._application_properties = value @property - def annotations(self): - # type: () -> Optional[Dict] + def annotations(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Service specific message annotations. @@ -294,13 +231,11 @@ def annotations(self): return self._annotations @annotations.setter - def annotations(self, value): - # type: (Dict) -> None + def annotations(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._annotations = value @property - def delivery_annotations(self): - # type: () -> Optional[Dict] + def delivery_annotations(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Delivery-specific non-standard properties at the head of the message. Delivery annotations convey information from the sending peer to the receiving peer. @@ -310,28 +245,23 @@ def delivery_annotations(self): return self._delivery_annotations @delivery_annotations.setter - def delivery_annotations(self, value): - # type: (Dict) -> None + def delivery_annotations(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._delivery_annotations = value @property - def header(self): - # type: () -> Optional[AmqpMessageHeader] + def header(self) -> Optional[AmqpMessageHeader]: """ The message header. - :rtype: Optional[~azure.eventhub.amqp.AmqpMessageHeader] """ return self._header @header.setter - def header(self, value): - # type: (AmqpMessageHeader) -> None + def header(self, value: AmqpMessageHeader) -> None: self._header = value @property - def footer(self): - # type: () -> Optional[Dict] + def footer(self) -> Optional[Dict[Any, Any]]: """ The message footer. @@ -340,8 +270,7 @@ def footer(self): return self._footer @footer.setter - def footer(self, value): - # type: (Dict) -> None + def footer(self, value: Optional[Dict[Any, Any]]) -> None: self._footer = value @@ -350,11 +279,9 @@ class AmqpMessageHeader(DictMixin): The Message header. This is only used on received message, and not set on messages being sent. The properties set on any given message will depend on the Service and not all messages will have all properties. - Please refer to the AMQP spec: http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#type-header for more information on the message header. - :keyword delivery_count: The number of unsuccessful previous attempts to deliver this message. If this value is non-zero it can be taken as an indication that the delivery might be a duplicate. On first delivery, the value is zero. It is @@ -425,11 +352,9 @@ class AmqpMessageProperties(DictMixin): The properties that are actually used will depend on the service implementation. Not all received messages will have all properties, and not all properties will be utilized on a sent message. - Please refer to the AMQP spec: http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#type-properties for more information on the message properties. - :keyword message_id: Message-id, if set, uniquely identifies a message within the message system. The message producer is usually responsible for setting the message-id in such a way that it is assured to be globally unique. A broker MAY discard a message as a duplicate if the value diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py new file mode 100644 index 000000000000..4bb676392f89 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py @@ -0,0 +1,27 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +def encode_str(data, encoding='utf-8'): + try: + return data.encode(encoding) + except AttributeError: + return data + +def normalized_data_body(data, **kwargs): + # A helper method to normalize input into AMQP Data Body format + encoding = kwargs.get("encoding", "utf-8") + if isinstance(data, list): + return [encode_str(item, encoding) for item in data] + else: + return [encode_str(data, encoding)] + + +def normalized_sequence_body(sequence): + # A helper method to normalize input into AMQP Sequence Body format + if isinstance(sequence, list) and all([isinstance(b, list) for b in sequence]): + return sequence + elif isinstance(sequence, list): + return [sequence] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_constants.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_constants.py index 1e2e7d3b6577..576321d4cc2b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_constants.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_constants.py @@ -4,8 +4,6 @@ # license information. # ------------------------------------------------------------------------- from enum import Enum - -from uamqp import MessageBodyType from azure.core import CaseInsensitiveEnumMeta @@ -13,10 +11,3 @@ class AmqpMessageBodyType(str, Enum, metaclass=CaseInsensitiveEnumMeta): DATA = "data" SEQUENCE = "sequence" VALUE = "value" - - -AMQP_MESSAGE_BODY_TYPE_MAP = { - MessageBodyType.Data.value: AmqpMessageBodyType.DATA, - MessageBodyType.Sequence.value: AmqpMessageBodyType.SEQUENCE, - MessageBodyType.Value.value: AmqpMessageBodyType.VALUE, -} diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py index 6d90033502f8..8001d97cea6d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py @@ -2,39 +2,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -import logging import six - -from uamqp import errors, compat - -from ._constants import NO_RETRY_ERRORS - -_LOGGER = logging.getLogger(__name__) - - -def _error_handler(error): - """ - Called internally when an event has failed to send so we - can parse the error to determine whether we should attempt - to retry sending the event again. - Returns the action to take according to error type. - - :param error: The error received in the send attempt. - :type error: Exception - :rtype: ~uamqp.errors.ErrorAction - """ - if error.condition == b"com.microsoft:server-busy": - return errors.ErrorAction(retry=True, backoff=4) - if error.condition == b"com.microsoft:timeout": - return errors.ErrorAction(retry=True, backoff=2) - if error.condition == b"com.microsoft:operation-cancelled": - return errors.ErrorAction(retry=True) - if error.condition == b"com.microsoft:container-close": - return errors.ErrorAction(retry=True, backoff=4) - if error.condition in NO_RETRY_ERRORS: - return errors.ErrorAction(retry=False) - return errors.ErrorAction(retry=True) - +try: + from uamqp import errors, compat +except ImportError: + errors = None + compat = None class EventHubError(Exception): """Represents an error occurred in the client. @@ -126,7 +99,7 @@ class OperationTimeoutError(EventHubError): class OwnershipLostError(Exception): """Raised when `update_checkpoint` detects the ownership to a partition has been lost.""" - +# TODO: delete when async unittests have been refactored def _create_eventhub_exception(exception): if isinstance(exception, errors.AuthenticationException): error = AuthenticationError(str(exception), exception) @@ -150,54 +123,3 @@ def _create_eventhub_exception(exception): else: error = EventHubError(str(exception), exception) return error - - -def _handle_exception( - exception, closable -): # pylint:disable=too-many-branches, too-many-statements - try: # closable is a producer/consumer object - name = closable._name # pylint: disable=protected-access - except AttributeError: # closable is an client object - name = closable._container_id # pylint: disable=protected-access - if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise - _LOGGER.info("%r stops due to keyboard interrupt", name) - closable._close_connection() # pylint:disable=protected-access - raise exception - elif isinstance(exception, EventHubError): - closable._close_handler() # pylint:disable=protected-access - raise exception - elif isinstance( - exception, - ( - errors.MessageAccepted, - errors.MessageAlreadySettled, - errors.MessageModified, - errors.MessageRejected, - errors.MessageReleased, - errors.MessageContentTooLarge, - ), - ): - _LOGGER.info("%r Event data error (%r)", name, exception) - error = EventDataError(str(exception), exception) - raise error - elif isinstance(exception, errors.MessageException): - _LOGGER.info("%r Event data send error (%r)", name, exception) - error = EventDataSendError(str(exception), exception) - raise error - else: - if isinstance(exception, errors.AuthenticationException): - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access - elif isinstance(exception, errors.LinkDetach): - if hasattr(closable, "_close_handler"): - closable._close_handler() # pylint:disable=protected-access - elif isinstance(exception, errors.ConnectionClose): - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access - elif isinstance(exception, errors.MessageHandlerError): - if hasattr(closable, "_close_handler"): - closable._close_handler() # pylint:disable=protected-access - else: # errors.AMQPConnectionError, compat.TimeoutException - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access - return _create_eventhub_exception(exception) From 76e11d14205799d182affe8782c037cd6500e8e8 Mon Sep 17 00:00:00 2001 From: swathipil Date: Mon, 1 Aug 2022 13:48:30 -0700 Subject: [PATCH 03/20] sync tests --- .../azure/eventhub/_client_base.py | 4 +- sdk/eventhub/azure-eventhub/conftest.py | 4 + .../azure-eventhub/dev_requirements.txt | 3 +- sdk/eventhub/azure-eventhub/tests/__init__.py | 4 + .../azure-eventhub/tests/_test_case.py | 7 + .../azure-eventhub/tests/livetest/__init__.py | 4 + .../tests/livetest/synctests/__init__.py | 4 + .../tests/livetest/synctests/test_auth.py | 42 +++- .../synctests/test_buffered_producer.py | 67 ++++-- .../synctests/test_consumer_client.py | 60 +++-- .../tests/livetest/synctests/test_negative.py | 61 +++-- .../livetest/synctests/test_properties.py | 42 +++- .../tests/livetest/synctests/test_receive.py | 70 +++++- .../livetest/synctests/test_reconnect.py | 91 +++++--- .../tests/livetest/synctests/test_send.py | 209 +++++++++++++----- .../azure-eventhub/tests/unittest/__init__.py | 4 + .../tests/unittest/test_event_data.py | 150 ++++++++++--- 17 files changed, 617 insertions(+), 209 deletions(-) create mode 100644 sdk/eventhub/azure-eventhub/tests/__init__.py create mode 100644 sdk/eventhub/azure-eventhub/tests/_test_case.py create mode 100644 sdk/eventhub/azure-eventhub/tests/livetest/__init__.py create mode 100644 sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py create mode 100644 sdk/eventhub/azure-eventhub/tests/unittest/__init__.py diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index 716425eda6d0..3e6e113c9097 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -290,8 +290,8 @@ def __init__( credential: CredentialTypes, **kwargs: Any, ) -> None: - self._uamqp_transport = True - self._amqp_transport = UamqpTransport + self._uamqp_transport = kwargs.pop("uamqp_transport", True) + self._amqp_transport = UamqpTransport if self._uamqp_transport else None self.eventhub_name = eventhub_name if not eventhub_name: diff --git a/sdk/eventhub/azure-eventhub/conftest.py b/sdk/eventhub/azure-eventhub/conftest.py index 802b99d429f9..79862e7a67eb 100644 --- a/sdk/eventhub/azure-eventhub/conftest.py +++ b/sdk/eventhub/azure-eventhub/conftest.py @@ -68,6 +68,10 @@ def get_logger(filename, level=logging.INFO): log = get_logger(None, logging.DEBUG) +@pytest.fixture(scope="session") +def timeout_factor(): + return 1000 # TODO: if pyamqp ReceiveClient is used, set to 1 + @pytest.fixture(scope="session") def resource_group(): try: diff --git a/sdk/eventhub/azure-eventhub/dev_requirements.txt b/sdk/eventhub/azure-eventhub/dev_requirements.txt index df47262912ac..308c4c22ead4 100644 --- a/sdk/eventhub/azure-eventhub/dev_requirements.txt +++ b/sdk/eventhub/azure-eventhub/dev_requirements.txt @@ -4,5 +4,4 @@ azure-mgmt-eventhub==10.0.0 azure-mgmt-resource==20.0.0 aiohttp>=3.0 --e ../../../tools/azure-devtools --e ../../servicebus/azure-servicebus \ No newline at end of file +-e ../../../tools/azure-devtools \ No newline at end of file diff --git a/sdk/eventhub/azure-eventhub/tests/__init__.py b/sdk/eventhub/azure-eventhub/tests/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/_test_case.py b/sdk/eventhub/azure-eventhub/tests/_test_case.py new file mode 100644 index 000000000000..f58d3d758ba1 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/_test_case.py @@ -0,0 +1,7 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ + +def get_decorator(): + return [True] diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/__init__.py b/sdk/eventhub/azure-eventhub/tests/livetest/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/livetest/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py index c00ea84067ea..e8d9c4dd8341 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py @@ -11,20 +11,27 @@ from azure.eventhub import EventData, EventHubProducerClient, EventHubConsumerClient, EventHubSharedKeyCredential from azure.eventhub._client_base import EventHubSASTokenCredential from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_client_secret_credential(live_eventhub): +def test_client_secret_credential(live_eventhub, uamqp_transport): credential = EnvironmentCredential() producer_client = EventHubProducerClient(fully_qualified_namespace=live_eventhub['hostname'], eventhub_name=live_eventhub['event_hub'], credential=credential, - user_agent='customized information') + user_agent='customized information', + uamqp_transport=uamqp_transport) consumer_client = EventHubConsumerClient(fully_qualified_namespace=live_eventhub['hostname'], eventhub_name=live_eventhub['event_hub'], consumer_group='$default', credential=credential, - user_agent='customized information') + user_agent='customized information', + uamqp_transport=uamqp_transport + ) with producer_client: batch = producer_client.create_batch(partition_id='0') batch.add(EventData(body='A single message')) @@ -50,11 +57,15 @@ def on_event(partition_context, event): assert list(on_event.event.body)[0] == 'A single message'.encode('utf-8') +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_client_sas_credential(live_eventhub): +def test_client_sas_credential(live_eventhub, uamqp_transport): # This should "just work" to validate known-good. hostname = live_eventhub['hostname'] - producer_client = EventHubProducerClient.from_connection_string(live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub']) + producer_client = EventHubProducerClient.from_connection_string( + live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub'], uamqp_transport=uamqp_transport + ) with producer_client: batch = producer_client.create_batch(partition_id='0') @@ -67,7 +78,8 @@ def test_client_sas_credential(live_eventhub): token = credential.get_token(auth_uri).token producer_client = EventHubProducerClient(fully_qualified_namespace=hostname, eventhub_name=live_eventhub['event_hub'], - credential=EventHubSASTokenCredential(token, time.time() + 3000)) + credential=EventHubSASTokenCredential(token, time.time() + 3000), + uamqp_transport=uamqp_transport) with producer_client: batch = producer_client.create_batch(partition_id='0') @@ -77,7 +89,8 @@ def test_client_sas_credential(live_eventhub): # Finally let's do it with SAS token + conn str token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token.decode()) conn_str_producer_client = EventHubProducerClient.from_connection_string(token_conn_str, - eventhub_name=live_eventhub['event_hub']) + eventhub_name=live_eventhub['event_hub'], + uamqp_transport=uamqp_transport) with conn_str_producer_client: batch = conn_str_producer_client.create_batch(partition_id='0') @@ -85,11 +98,15 @@ def test_client_sas_credential(live_eventhub): conn_str_producer_client.send_batch(batch) +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_client_azure_sas_credential(live_eventhub): +def test_client_azure_sas_credential(live_eventhub, uamqp_transport): # This should "just work" to validate known-good. hostname = live_eventhub['hostname'] - producer_client = EventHubProducerClient.from_connection_string(live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub']) + producer_client = EventHubProducerClient.from_connection_string( + live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub'], uamqp_transport=uamqp_transport + ) with producer_client: batch = producer_client.create_batch(partition_id='0') @@ -110,14 +127,17 @@ def test_client_azure_sas_credential(live_eventhub): producer_client.send_batch(batch) +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_client_azure_named_key_credential(live_eventhub): +def test_client_azure_named_key_credential(live_eventhub, uamqp_transport): credential = AzureNamedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) consumer_client = EventHubConsumerClient(fully_qualified_namespace=live_eventhub['hostname'], eventhub_name=live_eventhub['event_hub'], consumer_group='$default', credential=credential, - user_agent='customized information') + user_agent='customized information', + uamqp_transport=uamqp_transport) assert consumer_client.get_eventhub_properties() is not None diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py index 1360e699b8bf..be5801eda051 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py @@ -20,6 +20,9 @@ AmqpAnnotatedMessage, ) from azure.eventhub.exceptions import EventDataSendError, OperationTimeoutError, EventHubError +from ..._test_case import get_decorator + +uamqp_transport_vals = get_decorator() def random_pkey_generation(partitions): @@ -39,26 +42,29 @@ def random_pkey_generation(partitions): return dic +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest() -def test_producer_client_constructor(connection_str): +def test_producer_client_constructor(connection_str, uamqp_transport): def on_success(events, pid): pass def on_error(events, error, pid): pass with pytest.raises(TypeError): - EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True) + EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True, uamqp_transport=uamqp_transport) with pytest.raises(TypeError): - EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True, on_success=on_success) + EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True, on_success=on_success, uamqp_transport=uamqp_transport) with pytest.raises(TypeError): - EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True, on_error=on_error) + EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True, on_error=on_error, uamqp_transport=uamqp_transport) with pytest.raises(ValueError): EventHubProducerClient.from_connection_string( connection_str, buffered_mode=True, on_success=on_success, on_error=on_error, - max_wait_time=0 + max_wait_time=0, + uamqp_transport=uamqp_transport ) with pytest.raises(ValueError): EventHubProducerClient.from_connection_string( @@ -66,11 +72,14 @@ def on_error(events, error, pid): buffered_mode=True, on_success=on_success, on_error=on_error, - max_buffer_length=0 + max_buffer_length=0, + uamqp_transport=uamqp_transport ) @pytest.mark.liveTest +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.parametrize( "flush_after_sending, close_after_sending", [ @@ -80,13 +89,13 @@ def on_error(events, error, pid): ] ) @pytest.mark.liveTest -def test_basic_send_single_events_round_robin(connection_str, flush_after_sending, close_after_sending): +def test_basic_send_single_events_round_robin(connection_str, flush_after_sending, close_after_sending, uamqp_transport): received_events = defaultdict(list) def on_event(partition_context, event): received_events[partition_context.partition_id].append(event) - consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default", uamqp_transport=uamqp_transport) receive_thread = Thread(target=consumer.receive, args=(on_event,)) receive_thread.daemon = True receive_thread.start() @@ -107,7 +116,8 @@ def on_error(events, pid, err): connection_str, buffered_mode=True, on_success=on_success, - on_error=on_error + on_error=on_error, + uamqp_transport=uamqp_transport ) with producer: @@ -177,6 +187,8 @@ def on_error(events, pid, err): @pytest.mark.liveTest +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.parametrize( "flush_after_sending, close_after_sending", [ @@ -185,13 +197,13 @@ def on_error(events, pid, err): (False, False) ] ) -def test_basic_send_batch_events_round_robin(connection_str, flush_after_sending, close_after_sending): +def test_basic_send_batch_events_round_robin(connection_str, flush_after_sending, close_after_sending, uamqp_transport): received_events = defaultdict(list) def on_event(partition_context, event): received_events[partition_context.partition_id].append(event) - consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default", uamqp_transport=uamqp_transport) receive_thread = Thread(target=consumer.receive, args=(on_event,)) receive_thread.daemon = True receive_thread.start() @@ -209,7 +221,8 @@ def on_error(events, pid, err): connection_str, buffered_mode=True, on_success=on_success, - on_error=on_error + on_error=on_error, + uamqp_transport=uamqp_transport ) with producer: @@ -292,13 +305,15 @@ def on_error(events, pid, err): @pytest.mark.liveTest -def test_send_with_hybrid_partition_assignment(connection_str): +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_send_with_hybrid_partition_assignment(connection_str, uamqp_transport): received_events = defaultdict(list) def on_event(partition_context, event): received_events[partition_context.partition_id].append(event) - consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default", uamqp_transport=uamqp_transport) receive_thread = Thread(target=consumer.receive, args=(on_event,)) receive_thread.daemon = True receive_thread.start() @@ -316,7 +331,8 @@ def on_error(events, pid, err): connection_str, buffered_mode=True, on_success=on_success, - on_error=on_error + on_error=on_error, + uamqp_transport=uamqp_transport ) with producer: @@ -381,13 +397,15 @@ def on_error(events, pid, err): receive_thread.join() -def test_send_with_timing_configuration(connection_str): +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_send_with_timing_configuration(connection_str, uamqp_transport): received_events = defaultdict(list) def on_event(partition_context, event): received_events[partition_context.partition_id].append(event) - consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default", uamqp_transport=uamqp_transport) receive_thread = Thread(target=consumer.receive, args=(on_event,)) receive_thread.daemon = True receive_thread.start() @@ -408,7 +426,8 @@ def on_error(events, pid, err): buffered_mode=True, max_wait_time=10, on_success=on_success, - on_error=on_error + on_error=on_error, + uamqp_transport=uamqp_transport ) with producer: @@ -428,7 +447,8 @@ def on_error(events, pid, err): max_wait_time=1000, max_buffer_length=10, on_success=on_success, - on_error=on_error + on_error=on_error, + uamqp_transport=uamqp_transport ) sent_events.clear() @@ -457,13 +477,15 @@ def on_error(events, pid, err): @pytest.mark.liveTest -def test_long_sleep(connection_str): +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_long_sleep(connection_str, uamqp_transport): received_events = defaultdict(list) def on_event(partition_context, event): received_events[partition_context.partition_id].append(event) - consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default", uamqp_transport=uamqp_transport) receive_thread = Thread(target=consumer.receive, args=(on_event,)) receive_thread.daemon = True receive_thread.start() @@ -481,7 +503,8 @@ def on_error(events, pid, err): connection_str, buffered_mode=True, on_success=on_success, - on_error=on_error + on_error=on_error, + uamqp_transport=uamqp_transport ) with producer: diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py index 8da5ddeb6ead..e9320c83aed7 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py @@ -6,14 +6,24 @@ from azure.eventhub import EventHubConsumerClient from azure.eventhub._eventprocessor.in_memory_checkpoint_store import InMemoryCheckpointStore from azure.eventhub._constants import ALL_PARTITIONS +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() + +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_no_partition(connstr_senders): +def test_receive_no_partition(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders senders[0].send(EventData("Test EventData")) senders[1].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', receive_timeout=1) + client = EventHubConsumerClient.from_connection_string( + connection_str, + consumer_group='$default', + receive_timeout=1, + uamqp_transport=uamqp_transport + ) def on_event(partition_context, event): on_event.received += 1 @@ -36,7 +46,7 @@ def on_event(partition_context, event): args=(on_event,), kwargs={"starting_position": "-1"}) worker.start() - time.sleep(10) + time.sleep(20) assert on_event.received == 2 checkpoints = list(client._event_processors.values())[0]._checkpoint_store.list_checkpoints( on_event.namespace, on_event.eventhub_name, on_event.consumer_group @@ -45,11 +55,15 @@ def on_event(partition_context, event): assert len([checkpoint for checkpoint in checkpoints if checkpoint["sequence_number"] == on_event.sequence_number]) > 0 +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_partition(connstr_senders): +def test_receive_partition(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders senders[0].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) def on_event(partition_context, event): on_event.received += 1 @@ -73,17 +87,21 @@ def on_event(partition_context, event): assert on_event.eventhub_name == senders[0]._client.eventhub_name +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_load_balancing(connstr_senders): +def test_receive_load_balancing(connstr_senders, uamqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - test code using multiple threads. Sometimes OSX aborts python process") connection_str, senders = connstr_senders cs = InMemoryCheckpointStore() client1 = EventHubConsumerClient.from_connection_string( - connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1) + connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1, uamqp_transport=uamqp_transport + ) client2 = EventHubConsumerClient.from_connection_string( - connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1) + connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1, uamqp_transport=uamqp_transport + ) def on_event(partition_context, event): pass @@ -105,13 +123,17 @@ def on_event(partition_context, event): assert len(client2._event_processors[("$default", ALL_PARTITIONS)]._consumers) == 1 -def test_receive_batch_no_max_wait_time(connstr_senders): +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_receive_batch_no_max_wait_time(connstr_senders, uamqp_transport): '''Test whether callback is called when max_wait_time is None and max_batch_size has reached ''' connection_str, senders = connstr_senders senders[0].send(EventData("Test EventData")) senders[1].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) def on_event_batch(partition_context, event_batch): on_event_batch.received += len(event_batch) @@ -133,7 +155,7 @@ def on_event_batch(partition_context, event_batch): worker = threading.Thread(target=client.receive_batch, args=(on_event_batch,), kwargs={"starting_position": "-1"}) worker.start() - time.sleep(10) + time.sleep(20) assert on_event_batch.received == 2 checkpoints = list(client._event_processors.values())[0]._checkpoint_store.list_checkpoints( @@ -146,14 +168,14 @@ def on_event_batch(partition_context, event_batch): worker.join() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.parametrize("max_wait_time, sleep_time, expected_result", [(3, 10, []), - (3, 2, None), - ]) -def test_receive_batch_empty_with_max_wait_time(connection_str, max_wait_time, sleep_time, expected_result): + (3, 2, None)]) +def test_receive_batch_empty_with_max_wait_time(uamqp_transport, connection_str, max_wait_time, sleep_time, expected_result): '''Test whether event handler is called when max_wait_time > 0 and no event is received ''' - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', uamqp_transport=uamqp_transport) def on_event_batch(partition_context, event_batch): on_event_batch.event_batch = event_batch @@ -168,13 +190,17 @@ def on_event_batch(partition_context, event_batch): worker.join() -def test_receive_batch_early_callback(connstr_senders): +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_receive_batch_early_callback(connstr_senders, uamqp_transport): ''' Test whether the callback is called once max_batch_size reaches and before max_wait_time reaches. ''' connection_str, senders = connstr_senders for _ in range(10): senders[0].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) def on_event_batch(partition_context, event_batch): on_event_batch.received += len(event_batch) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py index 3b7249c2cace..d08700873537 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py @@ -19,17 +19,26 @@ ) from azure.eventhub import EventHubConsumerClient from azure.eventhub import EventHubProducerClient +try: + from azure.eventhub._transport._uamqp_transport import UamqpTransport +except (ImportError, ModuleNotFoundError): + UamqpTransport = None +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() + +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_send_batch_with_invalid_hostname(invalid_hostname): +def test_send_batch_with_invalid_hostname(invalid_hostname, uamqp_transport): + amqp_transport = UamqpTransport if uamqp_transport else None if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - it keeps reporting 'Unable to set external certificates' " "and blocking other tests") - client = EventHubProducerClient.from_connection_string(invalid_hostname) + client = EventHubProducerClient.from_connection_string(invalid_hostname, uamqp_transport=uamqp_transport) with client: with pytest.raises(ConnectError): - batch = EventDataBatch() + batch = EventDataBatch(amqp_transport=amqp_transport) batch.add(EventData("test data")) client.send_batch(batch) @@ -40,26 +49,29 @@ def on_error(events, pid, err): on_error.err = err on_error.err = None - client = EventHubProducerClient.from_connection_string(invalid_hostname, on_error=on_error) + client = EventHubProducerClient.from_connection_string(invalid_hostname, on_error=on_error, uamqp_transport=uamqp_transport) with client: - batch = EventDataBatch() + batch = EventDataBatch(amqp_transport=amqp_transport) batch.add(EventData("test data")) client.send_batch(batch) assert isinstance(on_error.err, ConnectError) on_error.err = None - client = EventHubProducerClient.from_connection_string(invalid_hostname, on_error=on_error) + client = EventHubProducerClient.from_connection_string(invalid_hostname, on_error=on_error, uamqp_transport=uamqp_transport) with client: client.send_event(EventData("test data")) assert isinstance(on_error.err, ConnectError) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_with_invalid_hostname_sync(invalid_hostname): +def test_receive_with_invalid_hostname_sync(invalid_hostname, uamqp_transport): def on_event(partition_context, event): pass - client = EventHubConsumerClient.from_connection_string(invalid_hostname, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + invalid_hostname, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client: thread = threading.Thread(target=client.receive, args=(on_event, )) @@ -69,23 +81,26 @@ def on_event(partition_context, event): thread.join() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_send_batch_with_invalid_key(invalid_key): - client = EventHubProducerClient.from_connection_string(invalid_key) +def test_send_batch_with_invalid_key(invalid_key, uamqp_transport): + client = EventHubProducerClient.from_connection_string(invalid_key, uamqp_transport=uamqp_transport) + amqp_transport = UamqpTransport if uamqp_transport else None try: with pytest.raises(ConnectError): - batch = EventDataBatch() + batch = EventDataBatch(amqp_transport=amqp_transport) batch.add(EventData("test data")) client.send_batch(batch) finally: client.close() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_send_batch_to_invalid_partitions(connection_str): +def test_send_batch_to_invalid_partitions(connection_str, uamqp_transport): partitions = ["XYZ", "-1", "1000", "-"] for p in partitions: - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) try: with pytest.raises(ConnectError): batch = client.create_batch(partition_id=p) @@ -95,11 +110,12 @@ def test_send_batch_to_invalid_partitions(connection_str): client.close() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_send_batch_too_large_message(connection_str): +def test_send_batch_too_large_message(connection_str, uamqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - open issue regarding message size") - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) try: data = EventData(b"A" * 1100000) batch = client.create_batch() @@ -109,9 +125,10 @@ def test_send_batch_too_large_message(connection_str): client.close() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_send_batch_null_body(connection_str): - client = EventHubProducerClient.from_connection_string(connection_str) +def test_send_batch_null_body(connection_str, uamqp_transport): + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) try: with pytest.raises(ValueError): data = EventData(None) @@ -122,20 +139,22 @@ def test_send_batch_null_body(connection_str): client.close() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_create_batch_with_invalid_hostname_sync(invalid_hostname): +def test_create_batch_with_invalid_hostname_sync(invalid_hostname, uamqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - it keeps reporting 'Unable to set external certificates' " "and blocking other tests") - client = EventHubProducerClient.from_connection_string(invalid_hostname) + client = EventHubProducerClient.from_connection_string(invalid_hostname, uamqp_transport=uamqp_transport) with client: with pytest.raises(ConnectError): client.create_batch(max_size_in_bytes=300) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_create_batch_with_too_large_size_sync(connection_str): - client = EventHubProducerClient.from_connection_string(connection_str) +def test_create_batch_with_too_large_size_sync(connection_str, uamqp_transport): + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: with pytest.raises(ValueError): client.create_batch(max_size_in_bytes=5 * 1024 * 1024) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py index eb197eec44b0..678fccabc106 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py @@ -9,60 +9,78 @@ from azure.eventhub import EventHubSharedKeyCredential from azure.eventhub import EventHubConsumerClient from azure.eventhub.exceptions import AuthenticationError, ConnectError, EventHubError +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() + +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_get_properties(live_eventhub): +def test_get_properties(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport + ) with client: properties = client.get_eventhub_properties() assert properties['eventhub_name'] == live_eventhub['event_hub'] and properties['partition_ids'] == ['0', '1'] +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_get_properties_with_auth_error_sync(live_eventhub): +def test_get_properties_with_auth_error_sync(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], "AaBbCcDdEeFf=")) + EventHubSharedKeyCredential(live_eventhub['key_name'], "AaBbCcDdEeFf="), + uamqp_transport=uamqp_transport + ) with client: with pytest.raises(AuthenticationError) as e: client.get_eventhub_properties() client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential("invalid", live_eventhub['access_key']) + EventHubSharedKeyCredential("invalid", live_eventhub['access_key']), uamqp_transport=uamqp_transport ) with client: with pytest.raises(AuthenticationError) as e: client.get_eventhub_properties() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_get_properties_with_connect_error(live_eventhub): +def test_get_properties_with_connect_error(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], "invalid", '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport ) with client: with pytest.raises(ConnectError) as e: client.get_eventhub_properties() client = EventHubConsumerClient("invalid.servicebus.windows.net", live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport ) with client: with pytest.raises(EventHubError) as e: # This can be either ConnectError or ConnectionLostError client.get_eventhub_properties() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_get_partition_ids(live_eventhub): +def test_get_partition_ids(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport + ) with client: partition_ids = client.get_partition_ids() assert partition_ids == ['0', '1'] +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest -def test_get_partition_properties(live_eventhub): +def test_get_partition_properties(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport + ) with client: properties = client.get_partition_properties('0') assert properties['eventhub_name'] == live_eventhub['event_hub'] \ diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py index 21d6e249581e..cec3cb429904 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py @@ -9,13 +9,19 @@ import pytest import time import datetime +import uamqp from azure.eventhub import EventData, TransportType, EventHubConsumerClient from azure.eventhub.exceptions import EventHubError +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() + +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_end_of_stream(connstr_senders): +def test_receive_end_of_stream(connstr_senders, uamqp_transport): def on_event(partition_context, event): if partition_context.partition_id == "0": assert event.body_as_str() == "Receiving only a single event" @@ -29,7 +35,9 @@ def on_event(partition_context, event): assert ", partition_key: 0" in event_str on_event.called = False connection_str, senders = connstr_senders - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client: thread = threading.Thread(target=client.receive, args=(on_event,), kwargs={"partition_id": "0", "starting_position": "@latest"}) @@ -43,6 +51,7 @@ def on_event(partition_context, event): thread.join() +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.parametrize("position, inclusive, expected_result", [("offset", False, "Exclusive"), ("offset", True, "Inclusive"), @@ -50,7 +59,7 @@ def on_event(partition_context, event): ("sequence", True, "Inclusive"), ("enqueued_time", False, "Exclusive")]) @pytest.mark.liveTest -def test_receive_with_event_position_sync(connstr_senders, position, inclusive, expected_result): +def test_receive_with_event_position_sync(uamqp_transport, connstr_senders, position, inclusive, expected_result): def on_event(partition_context, event): assert partition_context.last_enqueued_event_properties.get('sequence_number') == event.sequence_number assert partition_context.last_enqueued_event_properties.get('offset') == event.offset @@ -69,7 +78,9 @@ def on_event(partition_context, event): connection_str, senders = connstr_senders senders[0].send(EventData(b"Inclusive")) senders[1].send(EventData(b"Inclusive")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client: thread = threading.Thread(target=client.receive, args=(on_event,), kwargs={"starting_position": "-1", @@ -82,7 +93,9 @@ def on_event(partition_context, event): thread.join() senders[0].send(EventData(expected_result)) senders[1].send(EventData(expected_result)) - client2 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client2 = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client2: thread = threading.Thread(target=client2.receive, args=(on_event,), kwargs={"starting_position": on_event.event_position, @@ -90,14 +103,44 @@ def on_event(partition_context, event): "track_last_enqueued_event_properties": True}) thread.daemon = True thread.start() - time.sleep(10) + time.sleep(15) assert on_event.event.body_as_str() == expected_result thread.join() - +# TODO: after fixing message property mutability, test +#@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) +#@pytest.mark.liveTest +#def test_receive_modify_message_resend_sync(uamqp_transport, connstr_senders): +# received_modified = [False] +# def on_event(partition_context, event): +# message = event.message +# if message.properties.message_id == b'a1': +# message.properties.message_id = 'a2' +# senders[0].send(event) +# elif message.properties.message_id == b'a2': +# received_modified = [True] +# +# connection_str, senders = connstr_senders +# event = EventData("A", message_id='a1') +# senders[0].send(event) +# client = EventHubConsumerClient.from_connection_string( +# connection_str, consumer_group='$default', uamqp_transport=uamqp_transport +# ) +# with client: +# thread = threading.Thread(target=client.receive, args=(on_event,), +# kwargs={"partition_id": "0", "starting_position": "-1"}) +# thread.daemon = True +# thread.start() +# time.sleep(10) +# assert received_modified[0] +# thread.join() + + +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_owner_level(connstr_senders): +def test_receive_owner_level(connstr_senders, uamqp_transport): def on_event(partition_context, event): pass def on_error(partition_context, error): @@ -105,8 +148,8 @@ def on_error(partition_context, error): on_error.error = None connection_str, senders = connstr_senders - client1 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') - client2 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client1 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', uamqp_transport=uamqp_transport) + client2 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', uamqp_transport=uamqp_transport) with client1, client2: thread1 = threading.Thread(target=client1.receive, args=(on_event,), kwargs={"partition_id": "0", "starting_position": "-1", @@ -128,8 +171,10 @@ def on_error(partition_context, error): assert isinstance(on_error.error, EventHubError) +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_over_websocket_sync(connstr_senders): +def test_receive_over_websocket_sync(connstr_senders, uamqp_transport): app_prop = {"raw_prop": "raw_value"} content_type = "text/plain" message_id_base = "mess_id_sample_" @@ -143,7 +188,8 @@ def on_event(partition_context, event): connection_str, senders = connstr_senders client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', - transport_type=TransportType.AmqpOverWebsocket) + transport_type=TransportType.AmqpOverWebsocket, + uamqp_transport=uamqp_transport) event_list = [] for i in range(5): diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py index 0abfa7a12d2f..afdf608afbd1 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py @@ -7,23 +7,30 @@ import time import pytest -import uamqp -from uamqp import authentication, errors, c_uamqp, compat - - from azure.eventhub import ( EventData, EventHubSharedKeyCredential, EventHubProducerClient, - EventHubConsumerClient + EventHubConsumerClient, ) from azure.eventhub.exceptions import OperationTimeoutError +from azure.eventhub._utils import transform_outbound_single_message +import uamqp +from uamqp import compat +from azure.eventhub._transport._uamqp_transport import UamqpTransport +from ..._test_case import get_decorator + +uamqp_transport_vals = get_decorator() +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_with_long_interval_sync(live_eventhub, sleep): +def test_send_with_long_interval_sync(live_eventhub, sleep, uamqp_transport): test_partition = "0" sender = EventHubProducerClient(live_eventhub['hostname'], live_eventhub['event_hub'], - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], + live_eventhub['access_key']), uamqp_transport=uamqp_transport + ) with sender: batch = sender.create_batch(partition_id=test_partition) batch.add(EventData(b"A single event")) @@ -31,7 +38,7 @@ def test_send_with_long_interval_sync(live_eventhub, sleep): if sleep: time.sleep(250) else: - sender._producers[test_partition]._handler._connection._conn.destroy() + sender._producers[test_partition]._handler._connection.close() batch = sender.create_batch(partition_id=test_partition) batch.add(EventData(b"A single event")) sender.send_batch(batch) @@ -39,65 +46,95 @@ def test_send_with_long_interval_sync(live_eventhub, sleep): received = [] uri = "sb://{}/{}".format(live_eventhub['hostname'], live_eventhub['event_hub']) - sas_auth = authentication.SASTokenAuth.from_shared_access_key( - uri, live_eventhub['key_name'], live_eventhub['access_key']) - + sas_auth = SASTokenAuth( + uri, uri, live_eventhub['key_name'], live_eventhub['access_key'] + ) source = "amqps://{}/{}/ConsumerGroups/{}/Partitions/{}".format( live_eventhub['hostname'], live_eventhub['event_hub'], live_eventhub['consumer_group'], test_partition) - receiver = uamqp.ReceiveClient(source, auth=sas_auth, debug=False, timeout=5000, prefetch=500) + receiver = ReceiveClient(live_eventhub['hostname'], source, auth=sas_auth, debug=False, link_credit=500) try: receiver.open() # receive_message_batch() returns immediately once it receives any messages before the max_batch_size # and timeout reach. Could be 1, 2, or any number between 1 and max_batch_size. # So call it twice to ensure the two events are received. - received.extend([EventData._from_message(x) for x in receiver.receive_message_batch(max_batch_size=1, timeout=5000)]) - received.extend([EventData._from_message(x) for x in receiver.receive_message_batch(max_batch_size=1, timeout=5000)]) + received.extend([EventData._from_message(x) for x in receiver.receive_message_batch(max_batch_size=1, timeout=5)]) + received.extend([EventData._from_message(x) for x in receiver.receive_message_batch(max_batch_size=1, timeout=5)]) finally: receiver.close() assert len(received) == 2 assert list(received[0].body)[0] == b"A single event" +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers): +def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(conn_str=connection_str, idle_timeout=10) + amqp_transport = UamqpTransport + retry_total = 3 + # no retry, should just raise error + client = EventHubProducerClient.from_connection_string( + conn_str=connection_str, idle_timeout=10, retry_total=retry_total, uamqp_transport=uamqp_transport + ) with client: ed = EventData('data') sender = client._create_producer(partition_id='0') with sender: - sender._open_with_retry() - time.sleep(11) - sender._unsent_events = [ed.message] - ed.message.on_send_complete = sender._on_outcome + sender._open_with_retry() + time.sleep(11) + ed = transform_outbound_single_message(ed, EventData, amqp_transport.to_outgoing_amqp_message) + sender._unsent_events = [ed._message] + if uamqp_transport: + sender._unsent_events[0].on_send_complete = sender._on_outcome with pytest.raises((uamqp.errors.ConnectionClose, - uamqp.errors.MessageHandlerError, OperationTimeoutError)): - # Mac may raise OperationTimeoutError or MessageHandlerError + uamqp.errors.MessageHandlerError, OperationTimeoutError)): sender._send_event_data() + else: + # for pyamqp add later + pass + if uamqp_transport: sender._send_event_data_with_retry() + if not uamqp_transport: + client = EventHubProducerClient.from_connection_string( + conn_str=connection_str, idle_timeout=10, uamqp_transport=uamqp_transport + ) + with client: + ed = EventData('data') + sender = client._create_producer(partition_id='0') + with sender: + sender._open_with_retry() + time.sleep(11) + ed = transform_outbound_single_message(ed, EventData, amqp_transport.to_outgoing_amqp_message) + sender._unsent_events = [ed._message] + sender._send_event_data() + retry = 0 while retry < 3: try: - messages = receivers[0].receive_message_batch(max_batch_size=10, timeout=10000) + timeout = 10000 if uamqp_transport else 10 + messages = receivers[0].receive_message_batch(max_batch_size=10, timeout=timeout) if messages: received_ed1 = EventData._from_message(messages[0]) assert received_ed1.body_as_str() == 'data' break - except compat.TimeoutException: + except (compat.TimeoutException, TimeoutError): retry += 1 +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_receive_connection_idle_timeout_and_reconnect_sync(connstr_senders): +def test_receive_connection_idle_timeout_and_reconnect_sync(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders client = EventHubConsumerClient.from_connection_string( conn_str=connection_str, consumer_group='$default', - idle_timeout=10 + idle_timeout=10, + uamqp_transport=uamqp_transport ) def on_event_received(event): @@ -112,7 +149,7 @@ def on_event_received(event): senders[0].send(ed) consumer._handler.do_work() - assert consumer._handler._connection._state == c_uamqp.ConnectionState.DISCARDING + assert consumer._handler._connection._state == uamqp.c_uamqp.ConnectionState.DISCARDING duration = 10 now_time = time.time() diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py index 0276de9aae0d..264f63232cce 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py @@ -12,6 +12,7 @@ import sys import uamqp +from uamqp.message import MessageProperties from azure.eventhub import EventData, TransportType, EventDataBatch from azure.eventhub import EventHubProducerClient, EventHubConsumerClient from azure.eventhub.exceptions import EventDataSendError, OperationTimeoutError @@ -21,12 +22,20 @@ AmqpAnnotatedMessage, AmqpMessageProperties, ) +try: + from azure.eventhub._transport._uamqp_transport import UamqpTransport +except (ImportError, ModuleNotFoundError): + UamqpTransport = None +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_with_partition_key(connstr_receivers, live_eventhub): +def test_send_with_partition_key(connstr_receivers, live_eventhub, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: data_val = 0 for partition in [b"a", b"b", b"c", b"d", b"e", b"f"]: @@ -54,7 +63,7 @@ def test_send_with_partition_key(connstr_receivers, live_eventhub): for index, partition in enumerate(receivers): retry_total = 0 while retry_total < 3: - timeout = 5000 + retry_total * 1000 + timeout = (5 + retry_total) * timeout_factor try: received = partition.receive_message_batch(timeout=timeout) for message in received: @@ -97,12 +106,14 @@ def test_send_with_partition_key(connstr_receivers, live_eventhub): assert len(found_partition_keys) == 6 +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_and_receive_large_body_size(connstr_receivers): +def test_send_and_receive_large_body_size(connstr_receivers, uamqp_transport, timeout_factor): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - open issue regarding message size") connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: payload = 250 * 1024 batch = client.create_batch() @@ -111,8 +122,9 @@ def test_send_and_receive_large_body_size(connstr_receivers): client.send_event(EventData("A" * payload)) received = [] + timeout = 10 * timeout_factor for r in receivers: - received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=10000)]) + received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=timeout)]) assert len(received) == 2 assert len(list(received[0].body)[0]) == payload @@ -128,17 +140,19 @@ def test_send_and_receive_large_body_size(connstr_receivers): received = [] for r in receivers: - received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=10000)]) + received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=timeout)]) assert len(received) == 2 assert len(list(received[0].body)[0]) == payload assert len(list(received[1].body)[0]) == payload +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_amqp_annotated_message(connstr_receivers): +def test_send_amqp_annotated_message(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: sequence_body = [b'message', 123.456, True] footer = {'footer_key': 'footer_value'} @@ -174,7 +188,7 @@ def test_send_amqp_annotated_message(connstr_receivers): ) body_ed = """{"json_key": "json_val"}""" - prop_ed = {"raw_prop": "raw_value"} + prop_ed = {b"raw_prop": b"raw_value"} cont_type_ed = "text/plain" corr_id_ed = "corr_id" mess_id_ed = "mess_id" @@ -182,6 +196,7 @@ def test_send_amqp_annotated_message(connstr_receivers): event_data.content_type = cont_type_ed event_data.correlation_id = corr_id_ed event_data.message_id = mess_id_ed + event_data.properties = prop_ed batch = client.create_batch() batch.add(data_message) @@ -216,6 +231,7 @@ def check_values(event): assert event.correlation_id == corr_id_ed assert event.message_id == mess_id_ed assert event.content_type == cont_type_ed + assert event.properties == prop_ed assert event.body_type == AmqpMessageBodyType.DATA received_count["normal_msg"] += 1 elif raw_amqp_message.body_type == AmqpMessageBodyType.SEQUENCE: @@ -238,7 +254,8 @@ def on_event(partition_context, event): on_event.received = [] client = EventHubConsumerClient.from_connection_string(connection_str, - consumer_group='$default') + consumer_group='$default', + uamqp_transport=uamqp_transport) with client: thread = threading.Thread(target=client.receive, args=(on_event,), kwargs={"starting_position": "-1"}) @@ -254,30 +271,35 @@ def on_event(partition_context, event): assert received_count["normal_msg"] == 3 +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.parametrize("payload", - [b"", b"A single event"]) + [(b""), (b"A single event")]) @pytest.mark.liveTest -def test_send_and_receive_small_body(connstr_receivers, payload): +def test_send_and_receive_small_body(connstr_receivers, payload, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: batch = client.create_batch() batch.add(EventData(payload)) client.send_batch(batch) client.send_event(EventData(payload)) received = [] + timeout = 5 * timeout_factor for r in receivers: - received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=5000)]) + received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=timeout)]) assert len(received) == 2 assert list(received[0].body)[0] == payload assert list(received[1].body)[0] == payload +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_partition(connstr_receivers): +def test_send_partition(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + timeout = 5 * timeout_factor + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: batch = client.create_batch() @@ -291,8 +313,8 @@ def test_send_partition(connstr_receivers): client.send_batch(batch) client.send_event(EventData(b"Data"), partition_id="1") - partition_0 = receivers[0].receive_message_batch(timeout=5000) - partition_1 = receivers[1].receive_message_batch(timeout=5000) + partition_0 = receivers[0].receive_message_batch(timeout=timeout) + partition_1 = receivers[1].receive_message_batch(timeout=timeout) assert len(partition_1) >= 2 assert len(partition_0) + len(partition_1) == 4 @@ -309,16 +331,19 @@ def test_send_partition(connstr_receivers): client.send_event(EventData(b"Data"), partition_id="0") time.sleep(5) - partition_0 = receivers[0].receive_message_batch(timeout=5000) - partition_1 = receivers[1].receive_message_batch(timeout=5000) + partition_0 = receivers[0].receive_message_batch(timeout=timeout) + partition_1 = receivers[1].receive_message_batch(timeout=timeout) assert len(partition_0) >= 2 assert len(partition_0) + len(partition_1) == 4 +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_non_ascii(connstr_receivers): +def test_send_non_ascii(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + timeout = 5 * timeout_factor + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: batch = client.create_batch(partition_id="0") batch.add(EventData(u"é,è,à,ù,â,ê,î,ô,û")) @@ -330,8 +355,8 @@ def test_send_non_ascii(connstr_receivers): # receive_message_batch() returns immediately once it receives any messages before the max_batch_size # and timeout reach. Could be 1, 2, or any number between 1 and max_batch_size. # So call it twice to ensure the two events are received. - partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] + \ - [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] + partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=timeout)] + \ + [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=timeout)] assert len(partition_0) == 4 assert partition_0[0].body_as_str() == u"é,è,à,ù,â,ê,î,ô,û" assert partition_0[1].body_as_json() == {"foo": u"漢字"} @@ -339,13 +364,16 @@ def test_send_non_ascii(connstr_receivers): assert partition_0[3].body_as_json() == {"foo": u"漢字"} +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_multiple_partitions_with_app_prop(connstr_receivers): +def test_send_multiple_partitions_with_app_prop(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers + timeout = 5 * timeout_factor app_prop_key = "raw_prop" app_prop_value = "raw_value" app_prop = {app_prop_key: app_prop_value} - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: ed0 = EventData(b"Message 0") ed0.properties = app_prop @@ -361,20 +389,25 @@ def test_send_multiple_partitions_with_app_prop(connstr_receivers): client.send_batch(batch) client.send_event(ed1, partition_id="1") - partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] + partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=timeout)] assert len(partition_0) == 2 assert partition_0[0].properties[b"raw_prop"] == b"raw_value" assert partition_0[1].properties[b"raw_prop"] == b"raw_value" - partition_1 = [EventData._from_message(x) for x in receivers[1].receive_message_batch(timeout=5000)] + partition_1 = [EventData._from_message(x) for x in receivers[1].receive_message_batch(timeout=timeout)] assert len(partition_1) == 2 assert partition_1[0].properties[b"raw_prop"] == b"raw_value" assert partition_1[1].properties[b"raw_prop"] == b"raw_value" +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_over_websocket_sync(connstr_receivers): +def test_send_over_websocket_sync(connstr_receivers, uamqp_transport, timeout_factor): + timeout = 10 * timeout_factor connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) + client = EventHubProducerClient.from_connection_string( + connection_str, transport_type=TransportType.AmqpOverWebsocket, uamqp_transport=uamqp_transport + ) with client: batch = client.create_batch(partition_id="0") @@ -384,17 +417,22 @@ def test_send_over_websocket_sync(connstr_receivers): time.sleep(1) received = [] - received.extend(receivers[0].receive_message_batch(max_batch_size=5, timeout=10000)) + received.extend(receivers[0].receive_message_batch(max_batch_size=5, timeout=timeout)) assert len(received) == 2 +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers): +def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers + timeout = 5 * timeout_factor app_prop_key = "raw_prop" app_prop_value = "raw_value" app_prop = {app_prop_key: app_prop_value} - client = EventHubProducerClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) + client = EventHubProducerClient.from_connection_string( + connection_str, transport_type=TransportType.AmqpOverWebsocket, uamqp_transport=uamqp_transport + ) with client: event_data_batch = client.create_batch(max_size_in_bytes=100000) while True: @@ -407,61 +445,70 @@ def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers): client.send_batch(event_data_batch) received = [] for r in receivers: - received.extend(r.receive_message_batch(timeout=5000)) + received.extend(r.receive_message_batch(timeout=timeout)) assert len(received) >= 1 assert EventData._from_message(received[0]).properties[b"raw_prop"] == b"raw_value" +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_list(connstr_receivers): +def test_send_list(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + timeout = 10 * timeout_factor + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) payload = "A1" with client: client.send_batch([EventData(payload)]) received = [] for r in receivers: - received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=10000)]) + received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=timeout)]) assert len(received) == 1 assert received[0].body_as_str() == payload +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest -def test_send_list_partition(connstr_receivers): +def test_send_list_partition(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + timeout = 10 * timeout_factor + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) payload = "A1" with client: client.send_batch([EventData(payload)], partition_id="0") - message = receivers[0].receive_message_batch(timeout=10000)[0] + message = receivers[0].receive_message_batch(timeout=timeout)[0] received = EventData._from_message(message) assert received.body_as_str() == payload +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.parametrize("to_send, exception_type", [([EventData("A"*1024)]*1100, ValueError), - ("any str", AttributeError) - ]) + ("any str", AttributeError)]) @pytest.mark.liveTest -def test_send_list_wrong_data(connection_str, to_send, exception_type): - client = EventHubProducerClient.from_connection_string(connection_str) +def test_send_list_wrong_data(connection_str, to_send, exception_type, uamqp_transport): + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: with pytest.raises(exception_type): client.send_batch(to_send) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.parametrize("partition_id, partition_key", [("0", None), (None, "pk")]) -def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key): +def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key, uamqp_transport): # Use invalid_hostname because this is not a live test. - client = EventHubProducerClient.from_connection_string(invalid_hostname) - batch = EventDataBatch(partition_id=partition_id, partition_key=partition_key) + amqp_transport = UamqpTransport if uamqp_transport else None + client = EventHubProducerClient.from_connection_string(invalid_hostname, uamqp_transport=uamqp_transport) + batch = EventDataBatch(partition_id=partition_id, partition_key=partition_key, amqp_transport=amqp_transport) with client: with pytest.raises(TypeError): client.send_batch(batch, partition_id=partition_id, partition_key=partition_key) -def test_send_with_callback(connstr_receivers): +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) +def test_send_with_callback(connstr_receivers, uamqp_transport): def on_error(events, pid, err): on_error.err = err @@ -472,7 +519,7 @@ def on_success(events, pid): sent_events = [] on_error.err = None connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str, on_success=on_success, on_error=on_error) + client = EventHubProducerClient.from_connection_string(connection_str, on_success=on_success, on_error=on_error, uamqp_transport=uamqp_transport) with client: batch = client.create_batch() @@ -506,3 +553,65 @@ def on_success(events, pid): assert sent_events[-1][1] == "0" assert not on_error.err + +# TODO: add more checks after LegacyMessage has been added +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +@pytest.mark.liveTest +def test_send_message_modify_backcompat(connstr_receivers, uamqp_transport, timeout_factor): + connection_str, receivers = connstr_receivers + if uamqp_transport: + properties = MessageProperties + + timeout = 10 * timeout_factor + outgoing_event_data = EventData(body="hello") + message = outgoing_event_data.message + message.properties = properties(user_id='fake_user') + assert outgoing_event_data.message.properties.user_id == b'fake_user' + assert outgoing_event_data.message.state == uamqp.constants.MessageState.WaitingToBeSent + assert outgoing_event_data.message.delivery_annotations is None + assert outgoing_event_data.message.delivery_no is None + assert outgoing_event_data.message.delivery_tag is None + assert outgoing_event_data.message.on_send_complete is None + assert outgoing_event_data.message.footer is None + assert outgoing_event_data.message.retries == 0 + assert outgoing_event_data.message.idle_time == 0 + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) + with client: + client.send_batch([outgoing_event_data]) + received = [] + for r in receivers: + received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=timeout)]) + + assert len(received) == 1 + received_ed = received[0] + # check that setting properties directly on uamqp message doesn't update the outgoing message from the event data + assert received_ed.message.properties.user_id is None + assert received_ed.message.state == uamqp.constants.MessageState.ReceivedSettled + assert received_ed.message.delivery_annotations is None + assert received_ed.message.delivery_no >= 1 + assert received_ed.message.delivery_tag is None + assert received_ed.message.on_send_complete is None + assert received_ed.message.footer is None + assert received_ed.message.retries >= 0 + assert received_ed.message.idle_time >= 0 + + # setting message properties by calling event data properties SHOULD update the outgoing uamqp message + received_ed.properties = {'prop': 'test'} + received_ed.message_id = "id_message" + received_ed.content_type = "content type" + received_ed.correlation_id = "correlation" + + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) + with client: + client.send_batch([received_ed]) + received = [] + for r in receivers: + received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=timeout)]) + + assert len(received) == 1 + received_ed = received[0] + assert received_ed.message.application_properties == {b"prop": b"test"} + assert received_ed.message_id == "id_message" + assert received_ed.content_type == "content type" + assert received_ed.correlation_id == "correlation" diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/__init__.py b/sdk/eventhub/azure-eventhub/tests/unittest/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/unittest/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py index ef562c2628a5..a5dcf81ef2cc 100644 --- a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py @@ -1,9 +1,25 @@ +# -- coding: utf-8 -- +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- + import platform import pytest -import uamqp from packaging import version -from azure.eventhub.amqp import AmqpAnnotatedMessage +try: + import uamqp + from azure.eventhub._transport._uamqp_transport import UamqpTransport +except ImportError: + UamqpTransport = None + pass +from azure.eventhub.amqp import AmqpAnnotatedMessage, AmqpMessageHeader, AmqpMessageProperties from azure.eventhub import _common +from azure.eventhub._utils import transform_outbound_single_message +from .._test_case import get_decorator + +uamqp_transport_vals = get_decorator() pytestmark = pytest.mark.skipif(platform.python_implementation() == "PyPy", reason="This is ignored for PyPy") @@ -55,23 +71,28 @@ def test_app_properties(): assert event_data.properties["a"] == "b" -def test_sys_properties(): - properties = uamqp.message.MessageProperties() - properties.message_id = "message_id" - properties.user_id = "user_id" - properties.to = "to" - properties.subject = "subject" - properties.reply_to = "reply_to" - properties.correlation_id = "correlation_id" - properties.content_type = "content_type" - properties.content_encoding = "content_encoding" - properties.absolute_expiry_time = 1 - properties.creation_time = 1 - properties.group_id = "group_id" - properties.group_sequence = 1 - properties.reply_to_group_id = "reply_to_group_id" - message = uamqp.Message(properties=properties) - message.annotations = {_common.PROP_OFFSET: "@latest"} +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_sys_properties(uamqp_transport): + if uamqp_transport: + properties = uamqp.message.MessageProperties() + properties.message_id = "message_id" + properties.user_id = "user_id" + properties.to = "to" + properties.subject = "subject" + properties.reply_to = "reply_to" + properties.correlation_id = "correlation_id" + properties.content_type = "content_type" + properties.content_encoding = "content_encoding" + properties.absolute_expiry_time = 1 + properties.creation_time = 1 + properties.group_id = "group_id" + properties.group_sequence = 1 + properties.reply_to_group_id = "reply_to_group_id" + message = uamqp.message.Message(properties=properties) + message.annotations = {_common.PROP_OFFSET: "@latest"} + else: + pass ed = EventData._from_message(message) # type: EventData assert ed.system_properties[_common.PROP_OFFSET] == "@latest" @@ -90,39 +111,102 @@ def test_sys_properties(): assert ed.system_properties[_common.PROP_REPLY_TO_GROUP_ID] == properties.reply_to_group_id -def test_event_data_batch(): - batch = EventDataBatch(max_size_in_bytes=110, partition_key="par") +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_event_data_batch(uamqp_transport): + if uamqp_transport: + amqp_transport = UamqpTransport() + if version.parse(uamqp.__version__) >= version.parse("1.2.8"): + expected_result = 101 + else: + expected_result = 93 + else: + pass + batch = EventDataBatch(max_size_in_bytes=110, partition_key="par", amqp_transport=amqp_transport) batch.add(EventData("A")) assert str(batch) == "EventDataBatch(max_size_in_bytes=110, partition_id=None, partition_key='par', event_count=1)" assert repr(batch) == "EventDataBatch(max_size_in_bytes=110, partition_id=None, partition_key='par', event_count=1)" - # In uamqp v1.2.8, the encoding size of a message has changed. delivery_count in message header is now set to 0 - # instead of None according to the C spec. - # This uamqp change is transparent to EH users so it's not considered as a breaking change. However, it's breaking - # the unit test here. The solution is to add backward compatibility in test. - if version.parse(uamqp.__version__) >= version.parse("1.2.8"): - assert batch.size_in_bytes == 101 and len(batch) == 1 - else: - assert batch.size_in_bytes == 93 and len(batch) == 1 + assert batch.size_in_bytes == expected_result and len(batch) == 1 + with pytest.raises(ValueError): batch.add(EventData("A")) -def test_event_data_from_message(): - message = uamqp.Message('A') + +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) +def test_event_data_from_message(uamqp_transport): + if uamqp_transport: + amqp_transport = UamqpTransport() + else: + pass + annotated_message = AmqpAnnotatedMessage(data_body=b'A') + message = amqp_transport.to_outgoing_amqp_message(annotated_message) event = EventData._from_message(message) assert event.content_type is None assert event.correlation_id is None assert event.message_id is None event.content_type = 'content_type' - event.correlation_id = 'correlation_id' + event.correlation_id = 'correlation_id' event.message_id = 'message_id' assert event.content_type == 'content_type' - assert event.correlation_id == 'correlation_id' + assert event.correlation_id == 'correlation_id' assert event.message_id == 'message_id' + assert list(event.body) == [b'A'] + def test_amqp_message_str_repr(): data_body = b'A' message = AmqpAnnotatedMessage(data_body=data_body) assert str(message) == 'A' assert 'AmqpAnnotatedMessage(body=A, body_type=data' in repr(message) + + +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_amqp_message_from_message(uamqp_transport): + if uamqp_transport: + header = uamqp.message.MessageHeader() + header.delivery_count = 1 + header.time_to_live = 10000 + header.first_acquirer = True + header.durable = True + header.priority = 1 + properties = uamqp.message.MessageProperties() + properties.message_id = "message_id" + properties.user_id = "user_id" + properties.to = "to" + properties.subject = "subject" + properties.reply_to = "reply_to" + properties.correlation_id = "correlation_id" + properties.content_type = "content_type" + properties.content_encoding = "content_encoding" + properties.absolute_expiry_time = 1 + properties.creation_time = 1 + properties.group_id = "group_id" + properties.group_sequence = 1 + properties.reply_to_group_id = "reply_to_group_id" + message = uamqp.message.Message(header=header, properties=properties) + message.annotations = {_common.PROP_OFFSET: "@latest"} + else: + pass + + amqp_message = AmqpAnnotatedMessage(message=message) + assert amqp_message.properties.message_id == message.properties.message_id + assert amqp_message.properties.user_id == message.properties.user_id + assert amqp_message.properties.to == message.properties.to + assert amqp_message.properties.subject == message.properties.subject + assert amqp_message.properties.reply_to == message.properties.reply_to + assert amqp_message.properties.correlation_id == message.properties.correlation_id + assert amqp_message.properties.content_type == message.properties.content_type + assert amqp_message.properties.absolute_expiry_time == message.properties.absolute_expiry_time + assert amqp_message.properties.creation_time == message.properties.creation_time + assert amqp_message.properties.group_id == message.properties.group_id + assert amqp_message.properties.group_sequence == message.properties.group_sequence + assert amqp_message.properties.reply_to_group_id == message.properties.reply_to_group_id + assert amqp_message.header.time_to_live == message.header.ttl + assert amqp_message.header.delivery_count == message.header.delivery_count + assert amqp_message.header.first_acquirer == message.header.first_acquirer + assert amqp_message.header.durable == message.header.durable + assert amqp_message.header.priority == message.header.priority + assert amqp_message.annotations == message.message_annotations From 5622e5bbc79cd8d70987a0ef76a57548b1c56f7b Mon Sep 17 00:00:00 2001 From: swathipil Date: Mon, 1 Aug 2022 13:48:41 -0700 Subject: [PATCH 04/20] update changelog --- sdk/eventhub/azure-eventhub/CHANGELOG.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index d3cf3b66076f..45ec47ad48e6 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -1,5 +1,13 @@ # Release History +## 5.10.1 (Unreleased) + +### Bugs Fixed + +### Other Changes + +- Internal refactoring to support upcoming Pure Python AMQP-based release. + ## 5.10.0 (2022-06-08) ### Features Added From 8de034d307b10e342158a7e7fa1e5ab18082685c Mon Sep 17 00:00:00 2001 From: swathipil Date: Mon, 1 Aug 2022 14:13:34 -0700 Subject: [PATCH 05/20] fix reconn test --- .../azure-eventhub/tests/livetest/synctests/test_reconnect.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py index afdf608afbd1..2bed5ff4a332 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py @@ -46,7 +46,7 @@ def test_send_with_long_interval_sync(live_eventhub, sleep, uamqp_transport): received = [] uri = "sb://{}/{}".format(live_eventhub['hostname'], live_eventhub['event_hub']) - sas_auth = SASTokenAuth( + sas_auth = uamqp.authentication.SASTokenAuth( uri, uri, live_eventhub['key_name'], live_eventhub['access_key'] ) source = "amqps://{}/{}/ConsumerGroups/{}/Partitions/{}".format( @@ -54,7 +54,7 @@ def test_send_with_long_interval_sync(live_eventhub, sleep, uamqp_transport): live_eventhub['event_hub'], live_eventhub['consumer_group'], test_partition) - receiver = ReceiveClient(live_eventhub['hostname'], source, auth=sas_auth, debug=False, link_credit=500) + receiver = uamqp.ReceiveClient(live_eventhub['hostname'], source, auth=sas_auth, debug=False, link_credit=500) try: receiver.open() # receive_message_batch() returns immediately once it receives any messages before the max_batch_size From 9625aade13e673077bc7e7166814db98c701d510 Mon Sep 17 00:00:00 2001 From: swathipil Date: Mon, 1 Aug 2022 19:50:56 -0700 Subject: [PATCH 06/20] add switch to async --- .../azure/eventhub/_client_base.py | 5 +- .../azure-eventhub/azure/eventhub/_common.py | 3 +- .../azure/eventhub/_consumer.py | 4 +- .../azure/eventhub/_producer_client.py | 2 - .../azure-eventhub/azure/eventhub/_utils.py | 24 +- .../_buffered_producer_async.py | 16 +- .../_buffered_producer_dispatcher_async.py | 5 + .../azure/eventhub/aio/_client_base_async.py | 112 +++--- .../eventhub/aio/_connection_manager_async.py | 63 ++++ .../azure/eventhub/aio/_consumer_async.py | 93 ++--- .../eventhub/aio/_consumer_client_async.py | 6 +- .../azure/eventhub/aio/_error_async.py | 74 ---- .../azure/eventhub/aio/_producer_async.py | 105 +++--- .../eventhub/aio/_producer_client_async.py | 13 +- .../azure/eventhub/aio/_transport/__init__.py | 4 + .../eventhub/aio/_transport/_base_async.py | 232 ++++++++++++ .../aio/_transport/_uamqp_transport_async.py | 356 ++++++++++++++++++ .../azure/eventhub/exceptions.py | 25 -- 18 files changed, 829 insertions(+), 313 deletions(-) delete mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/aio/_error_async.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/__init__.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index 3e6e113c9097..f72611edca5a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -194,7 +194,6 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # type: (str, Any) -> AccessToken if not scopes: raise ValueError("No token scope provided.") - return _generate_sas_token(scopes[0], self.policy, self.key) @@ -291,7 +290,7 @@ def __init__( **kwargs: Any, ) -> None: self._uamqp_transport = kwargs.pop("uamqp_transport", True) - self._amqp_transport = UamqpTransport if self._uamqp_transport else None + self._amqp_transport = kwargs.pop("amqp_transport", UamqpTransport) self.eventhub_name = eventhub_name if not eventhub_name: @@ -302,7 +301,7 @@ def __init__( if isinstance(credential, AzureSasCredential): self._credential = EventhubAzureSasTokenCredential(credential) elif isinstance(credential, AzureNamedKeyCredential): - self._credential = EventhubAzureNamedKeyTokenCredential(credential) + self._credential = EventhubAzureNamedKeyTokenCredential(credential) # type: ignore else: self._credential = credential # type: ignore self._keep_alive = kwargs.get("keep_alive", 30) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index b302b9c7727e..26f5d42b53cf 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -511,8 +511,9 @@ def __init__( **kwargs, ) -> None: # TODO: this changes API, check with Anna if valid - - # If possible, move out message creation to right before sending. + # Need move out message creation to right before sending. # Might take more time to loop through events and add them all to batch in `send` than in `add` here + # Default async vs sync might cause issues. self._amqp_transport = kwargs.pop("amqp_transport", UamqpTransport) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index 6015bf0186f2..ad647a011a1a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from typing import Deque - from uamqp import ReceiveClient as uamqp_ReceiveClient, Message as uamqp_Message + from uamqp import ReceiveClient as uamqp_ReceiveClient, Message as uamqp_Message, types as uamqp_types from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth from ._consumer_client import EventHubConsumerClient @@ -97,7 +97,7 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs: Any) self._auto_reconnect = auto_reconnect self._retry_policy = self._amqp_transport.create_retry_policy(self._client._config) self._reconnect_backoff = 1 - link_properties: Dict[bytes, int] = {} + link_properties: Dict[uamqp_types.AMQPTypes, uamqp_types.AMQPType] = {} self._error = None self._timeout = 0 self._idle_timeout = (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) if idle_timeout else None diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index 3f00c3d501ae..1fecbd80f978 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -19,8 +19,6 @@ ) from typing_extensions import Literal -from .exceptions import ConnectError, EventHubError -from .amqp import AmqpAnnotatedMessage from ._client_base import ClientBase from ._producer import EventHubProducer from ._constants import ALL_PARTITIONS, MAX_MESSAGE_LENGTH_BYTES diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index 410bd2ff5536..a3a0e25df56a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -43,7 +43,6 @@ # TODO: remove after fixing up async from uamqp import types -from uamqp.message import MessageHeader PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) @@ -132,7 +131,7 @@ def send_context_manager(): else: yield None -# TODO: delete after async unit tests have been refactored + def set_event_partition_key(event, partition_key): # type: (Union[AmqpAnnotatedMessage, EventData], Optional[Union[bytes, str]]) -> None if not partition_key: @@ -155,27 +154,6 @@ def set_event_partition_key(event, partition_key): raw_message.header.durable = True -def set_message_partition_key(message, partition_key): - # type: (Message, Optional[Union[bytes, str]]) -> None - """Set the partition key as an annotation on a uamqp message. - - :param ~uamqp.Message message: The message to update. - :param str partition_key: The partition key value. - :rtype: None - """ - if partition_key: - annotations = message.annotations - if annotations is None: - annotations = dict() - annotations[ - PROP_PARTITION_KEY_AMQP_SYMBOL - ] = partition_key # pylint:disable=protected-access - header = MessageHeader() - header.durable = True - message.annotations = annotations - message.header = header - - @contextmanager def send_context_manager(): span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py index 2d98878d5146..67fee8dd2a58 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import asyncio import logging import queue @@ -14,6 +15,7 @@ from ...exceptions import OperationTimeoutError if TYPE_CHECKING: + from .._transport._base_async import AmqpTransportAsync from ..._producer_client import SendEventTypes _LOGGER = logging.getLogger(__name__) @@ -32,7 +34,8 @@ def __init__( max_message_size_on_link: int, *, max_wait_time: float = 1, - max_buffer_length: int + max_buffer_length: int, + amqp_transport: AmqpTransportAsync ): self._buffered_queue: queue.Queue = queue.Queue() self._max_buffer_len = max_buffer_length @@ -47,11 +50,12 @@ def __init__( self._cur_batch: Optional[EventDataBatch] = None self._max_message_size_on_link = max_message_size_on_link self._check_max_wait_time_future = None + self._amqp_transport = amqp_transport self.partition_id = partition_id async def start(self): async with self._lock: - self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) self._running = True if self._max_wait_time: self._last_send_time = time.time() @@ -113,11 +117,11 @@ async def put_events(self, events, timeout_time=None): self._buffered_queue.put(self._cur_batch) self._buffered_queue.put(events) # create a new batch for incoming events - self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) except ValueError: # add single event exceeds the cur batch size, create new batch self._buffered_queue.put(self._cur_batch) - self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) self._cur_batch.add(events) self._cur_buffered_len += new_events_len @@ -145,7 +149,7 @@ async def _flush(self, timeout_time=None, raise_error=True): _LOGGER.info("Partition: %r started flushing.", self.partition_id) if self._cur_batch: # if there is batch, enqueue it to the buffer first self._buffered_queue.put(self._cur_batch) - self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) while self._cur_buffered_len: remaining_time = timeout_time - time.time() if timeout_time else None if (remaining_time and remaining_time > 0) or remaining_time is None: @@ -187,7 +191,7 @@ async def _flush(self, timeout_time=None, raise_error=True): break # after finishing flushing, reset cur batch and put it into the buffer self._last_send_time = time.time() - self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) _LOGGER.info("Partition %r finished flushing.", self.partition_id) async def check_max_wait_time_worker(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py index ecae49098086..04e5a12ea69f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import asyncio import logging from typing import Dict, List, Callable, Optional, Awaitable, TYPE_CHECKING @@ -13,6 +14,7 @@ from ...exceptions import EventDataSendError, ConnectError, EventHubError if TYPE_CHECKING: + from .._transport._base_async import AmqpTransportAsync from ..._producer_client import SendEventTypes _LOGGER = logging.getLogger(__name__) @@ -33,6 +35,7 @@ def __init__( *, max_buffer_length: int = 1500, max_wait_time: float = 1, + amqp_transport: AmqpTransportAsync, ): self._buffered_producers: Dict[str, BufferedProducer] = {} self._partition_ids: List[str] = partitions @@ -45,6 +48,7 @@ def __init__( self._partition_resolver = PartitionResolver(self._partition_ids) self._max_wait_time = max_wait_time self._max_buffer_length = max_buffer_length + self._amqp_transport = amqp_transport async def _get_partition_id(self, partition_id, partition_key): if partition_id: @@ -77,6 +81,7 @@ async def enqueue_events( self._max_message_size_on_link, max_wait_time=self._max_wait_time, max_buffer_length=self._max_buffer_length, + amqp_transport=self._amqp_transport, ) await buffered_producer.start() self._buffered_producers[pid] = buffered_producer diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py index 5a5a312485c4..b82ae698ed6b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import logging import asyncio @@ -39,10 +39,11 @@ MGMT_PARTITION_OPERATION, MGMT_STATUS_CODE, MGMT_STATUS_DESC, + READ_OPERATION, ) from ._async_utils import get_dict_with_loop_if_needed from ._connection_manager_async import get_connection_manager -from ._error_async import _handle_exception +from ._transport._uamqp_transport_async import UamqpTransportAsync if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -211,6 +212,8 @@ def __init__( **kwargs: Any ) -> None: self._internal_kwargs = get_dict_with_loop_if_needed(kwargs.get("loop", None)) + self._uamqp_transport = kwargs.pop("uamqp_transport", True) + self._amqp_transport = UamqpTransportAsync if isinstance(credential, AzureSasCredential): self._credential = EventhubAzureSasTokenCredentialAsync(credential) # type: ignore elif isinstance(credential, AzureNamedKeyCredential): @@ -221,6 +224,8 @@ def __init__( fully_qualified_namespace=fully_qualified_namespace, eventhub_name=eventhub_name, credential=self._credential, + uamqp_transport=self._uamqp_transport, + amqp_transport=self._amqp_transport, **kwargs ) self._conn_manager_async = get_connection_manager(**kwargs) @@ -255,32 +260,19 @@ async def _create_auth_async(self) -> authentication.JWTTokenAsync: except AttributeError: token_type = b"jwt" if token_type == b"servicebus.windows.net:sastoken": - auth = authentication.JWTTokenAsync( - self._auth_uri, + return await self._amqp_transport.create_token_auth( self._auth_uri, functools.partial(self._credential.get_token, self._auth_uri), token_type=token_type, - timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, - custom_endpoint_hostname=self._config.custom_endpoint_hostname, - port=self._config.connection_port, - verify=self._config.connection_verify, - refresh_window=300, + config=self._config, + update_token=True, ) - await auth.update_token() - return auth - return authentication.JWTTokenAsync( - self._auth_uri, + return await self._amqp_transport.create_token_auth( self._auth_uri, functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE), token_type=token_type, - timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, - custom_endpoint_hostname=self._config.custom_endpoint_hostname, - port=self._config.connection_port, - verify=self._config.connection_verify, + config=self._config, + update_token=False, ) async def _close_connection_async(self) -> None: @@ -322,19 +314,21 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> last_exception = None while retried_times <= self._config.max_retries: mgmt_auth = await self._create_auth_async() - mgmt_client = AMQPClientAsync( - self._mgmt_target, auth=mgmt_auth, debug=self._config.network_tracing + mgmt_client = self._amqp_transport.create_mgmt_client( + self._address, mgmt_auth=mgmt_auth, config=self._config ) try: - conn = await self._conn_manager_async.get_connection( - self._address.hostname, mgmt_auth - ) - mgmt_msg.application_properties["security_token"] = mgmt_auth.token - await mgmt_client.open_async(connection=conn) - response = await mgmt_client.mgmt_request_async( + await mgmt_client.open_async() + while not await mgmt_client.client_ready_async(): + await asyncio.sleep(0.05) + mgmt_msg.application_properties[ + "security_token" + ] = await self._amqp_transport.get_updated_token(mgmt_auth) + response = await self._amqp_transport.mgmt_client_request( + mgmt_client, mgmt_msg, - constants.READ_OPERATION, - op_type=op_type, + operation=READ_OPERATION, + operation_type=op_type, status_code_field=MGMT_STATUS_CODE, description_fields=MGMT_STATUS_DESC, ) @@ -347,26 +341,23 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> if status_code < 400: return response if status_code in [401]: - raise errors.AuthenticationException( - "Management authentication failed. Status code: {}, Description: {!r}".format( - status_code, description - ) + raise self._amqp_transport.get_error( + self._amqp_transport.AUTH_EXCEPTION, + f"Management authentication failed. Status code: {status_code}, Description: {description!r}" ) if status_code in [404]: - raise ConnectError( - "Management connection failed. Status code: {}, Description: {!r}".format( - status_code, description - ) - ) - raise errors.AMQPConnectionError( - "Management request error. Status code: {}, Description: {!r}".format( - status_code, description + return self._amqp_transport.get_error( + self._amqp_transport.CONNECTION_ERROR, + f"Management connection failed. Status code: {status_code}, Description: {description!r}" ) + return self._amqp_transport.get_error( + self._amqp_transport.AMQP_CONNECTION_ERROR, + f"Management request error. Status code: {status_code}, Description: {description!r}" ) except asyncio.CancelledError: # pylint: disable=try-except-raise raise except Exception as exception: # pylint:disable=broad-except - last_exception = await _handle_exception(exception, self) + last_exception = await self._amqp_transport._handle_exception(exception, self) # pylint: disable=protected-access await self._backoff_async( retried_times=retried_times, last_exception=last_exception ) @@ -380,12 +371,14 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> await mgmt_client.close_async() async def _get_eventhub_properties_async(self) -> Dict[str, Any]: - mgmt_msg = Message(application_properties={"name": self.eventhub_name}) + mgmt_msg = mgmt_msg = self._amqp_transport.MESSAGE( + application_properties={"name": self.eventhub_name} + ) response = await self._management_request_async( mgmt_msg, op_type=MGMT_OPERATION ) output = {} - eh_info = response.get_data() # type: Dict[bytes, Any] + eh_info: Dict[bytes, Any] = response.value if eh_info: output["eventhub_name"] = eh_info[b"name"].decode("utf-8") output["created_at"] = utc_from_timestamp( @@ -402,7 +395,7 @@ async def _get_partition_ids_async(self) -> List[str]: async def _get_partition_properties_async( self, partition_id: str ) -> Dict[str, Any]: - mgmt_msg = Message( + mgmt_msg = self._amqp_transport.MESSAGE( application_properties={ "name": self.eventhub_name, "partition": partition_id, @@ -411,7 +404,7 @@ async def _get_partition_properties_async( response = await self._management_request_async( mgmt_msg, op_type=MGMT_PARTITION_OPERATION ) - partition_info = response.get_data() # type: Dict[bytes, Union[bytes, int]] + partition_info = response.value # type: Dict[bytes, Union[bytes, int]] output = {} # type: Dict[str, Any] if partition_info: output["eventhub_name"] = cast(bytes, partition_info[b"name"]).decode( @@ -463,16 +456,12 @@ async def _open(self) -> None: await self._handler.close_async() auth = await self._client._create_auth_async() self._create_handler(auth) - await self._handler.open_async( - connection=await self._client._conn_manager_async.get_connection( - self._client._address.hostname, auth - ) - ) + await self._handler.open_async() while not await self._handler.client_ready_async(): await asyncio.sleep(0.05, **self._internal_kwargs) self._max_message_size_on_link = ( - self._handler.message_handler._link.peer_max_message_size - or constants.MAX_MESSAGE_LENGTH_BYTES + self._amqp_transport.get_remote_max_message_size(self._handler) + or constants.MAX_FRAME_SIZE_BYTES ) self.running = True @@ -487,11 +476,14 @@ async def _close_connection_async(self) -> None: await self._client._conn_manager_async.reset_connection_if_broken() # pylint:disable=protected-access async def _handle_exception(self, exception: Exception) -> Exception: - if not self.running and isinstance(exception, compat.TimeoutException): - exception = errors.AuthenticationException("Authorization timeout.") - return await _handle_exception(exception, self) - - return await _handle_exception(exception, self) + if not self.running and isinstance(exception, self._amqp_transport.TIMEOUT_EXCEPTION): + exception = self._amqp_transport.get_error( + self._amqp_transport.AUTH_EXCEPTION, + "Authorization timeout." + ) + return await self._amqp_transport._handle_exception( # pylint: disable=protected-access + exception, self + ) async def _do_retryable_operation( self, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py index d53443f12fd7..32e544344989 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py @@ -5,8 +5,12 @@ from typing import TYPE_CHECKING +from uamqp import c_uamqp from uamqp.async_ops import ConnectionAsync +from .._connection_manager import _ConnectionMode +from .._constants import TransportType + if TYPE_CHECKING: from uamqp.authentication import JWTTokenAsync @@ -28,6 +32,62 @@ async def reset_connection_if_broken(self) -> None: pass +class _SharedConnectionManager(object): # pylint:disable=too-many-instance-attributes + def __init__(self, **kwargs) -> None: + self._loop = kwargs.get("loop") + self._lock = Lock(loop=self._loop) + self._conn = None + + self._container_id = kwargs.get("container_id") + self._debug = kwargs.get("debug") + self._error_policy = kwargs.get("error_policy") + self._properties = kwargs.get("properties") + self._encoding = kwargs.get("encoding") or "UTF-8" + self._transport_type = kwargs.get("transport_type") or TransportType.Amqp + self._http_proxy = kwargs.get("http_proxy") + self._max_frame_size = kwargs.get("max_frame_size") + self._channel_max = kwargs.get("channel_max") + self._idle_timeout = kwargs.get("idle_timeout") + self._remote_idle_timeout_empty_frame_send_ratio = kwargs.get( + "remote_idle_timeout_empty_frame_send_ratio" + ) + + async def get_connection(self, host: str, auth: "JWTTokenAsync") -> ConnectionAsync: + async with self._lock: + if self._conn is None: + self._conn = ConnectionAsync( + host, + auth, + container_id=self._container_id, + max_frame_size=self._max_frame_size, + channel_max=self._channel_max, + idle_timeout=self._idle_timeout, + properties=self._properties, + remote_idle_timeout_empty_frame_send_ratio=self._remote_idle_timeout_empty_frame_send_ratio, + error_policy=self._error_policy, + debug=self._debug, + loop=self._loop, + encoding=self._encoding, + ) + return self._conn + + async def close_connection(self) -> None: + async with self._lock: + if self._conn: + await self._conn.destroy_async() + self._conn = None + + async def reset_connection_if_broken(self) -> None: + async with self._lock: + if self._conn and self._conn._state in ( # pylint:disable=protected-access + c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.DISCARDING, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.END, # pylint:disable=c-extension-no-member + ): + self._conn = None + + class _SeparateConnectionManager(object): def __init__(self, **kwargs) -> None: pass @@ -43,4 +103,7 @@ async def reset_connection_if_broken(self) -> None: def get_connection_manager(**kwargs) -> "ConnectionManager": + connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection) + if connection_mode == _ConnectionMode.ShareConnection: + return _SharedConnectionManager(**kwargs) return _SeparateConnectionManager(**kwargs) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py index afdcad0ad9e2..f3bcb0a36636 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import time import asyncio import uuid @@ -9,19 +10,16 @@ from collections import deque from typing import TYPE_CHECKING, Callable, Awaitable, cast, Dict, Optional, Union, List -import uamqp -from uamqp import errors, types, utils -from uamqp import ReceiveClientAsync, Source - from ._client_base_async import ConsumerProducerMixin from ._async_utils import get_dict_with_loop_if_needed from .._common import EventData -from ..exceptions import _error_handler from .._utils import create_properties, event_position_selector from .._constants import EPOCH_SYMBOL, TIMEOUT_SYMBOL, RECEIVER_RUNTIME_METRIC_SYMBOL if TYPE_CHECKING: from typing import Deque + import uamqp + from uamqp import ReceiveClientAsync, Source, types from uamqp.authentication import JWTTokenAsync from ._consumer_client_async import EventHubConsumerClient @@ -79,9 +77,10 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs) -> N self.running = False self.closed = False - self._on_event_received = kwargs[ + self._amqp_transport = kwargs.pop("amqp_transport") + self._on_event_received: Callable[[Union[Optional[EventData], List[EventData]]], Awaitable[None]] = kwargs[ "on_event_received" - ] # type: Callable[[Union[Optional[EventData], List[EventData]]], Awaitable[None]] + ] self._internal_kwargs = get_dict_with_loop_if_needed(kwargs.get("loop", None)) self._client = client self._source = source @@ -91,81 +90,65 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs) -> N self._owner_level = owner_level self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect - self._retry_policy = errors.ErrorPolicy( - max_retries=self._client._config.max_retries, - on_error=_error_handler, # pylint:disable=protected-access - ) + self._retry_policy = self._amqp_transport.create_retry_policy(self._client._config) self._reconnect_backoff = 1 self._timeout = 0 - self._idle_timeout = (idle_timeout * 1000) if idle_timeout else None - self._link_properties = {} # type: Dict[types.AMQPType, types.AMQPType] + self._idle_timeout = (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) if idle_timeout else None + link_properties: Dict[types.AMQPType, types.AMQPType] = {} partition = self._source.split("/")[-1] self._partition = partition - self._name = "EHReceiver-{}-partition{}".format(uuid.uuid4(), partition) + self._name = f"EHReceiver-{uuid.uuid4()}-partition{partition}" if owner_level is not None: - self._link_properties[types.AMQPSymbol(EPOCH_SYMBOL)] = types.AMQPLong( - int(owner_level) - ) + link_properties[EPOCH_SYMBOL] = int(owner_level) link_property_timeout_ms = ( self._client._config.receive_timeout or self._timeout # pylint:disable=protected-access - ) * 1000 - self._link_properties[types.AMQPSymbol(TIMEOUT_SYMBOL)] = types.AMQPLong( - int(link_property_timeout_ms) - ) - self._handler = None # type: Optional[ReceiveClientAsync] + ) * self._amqp_transport.IDLE_TIMEOUT_FACTOR + link_properties[TIMEOUT_SYMBOL] = int(link_property_timeout_ms) + self._link_properties = self._amqp_transport.create_link_properties(link_properties) + self._handler: Optional[ReceiveClientAsync] = None self._track_last_enqueued_event_properties = ( track_last_enqueued_event_properties ) - self._message_buffer = deque() # type: Deque[uamqp.Message] - self._last_received_event = None # type: Optional[EventData] - - def _create_handler(self, auth: "JWTTokenAsync") -> None: - source = Source(self._source) - if self._offset is not None: - source.set_filter( - event_position_selector(self._offset, self._offset_inclusive) - ) - desired_capabilities = None - if self._track_last_enqueued_event_properties: - symbol_array = [types.AMQPSymbol(RECEIVER_RUNTIME_METRIC_SYMBOL)] - desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) - - properties = create_properties( - self._client._config.user_agent # pylint:disable=protected-access + self._message_buffer: Deque[uamqp.Message] = deque() + self._last_received_event: Optional[EventData] = None + + def _create_handler(self, auth: JWTTokenAsync) -> None: + source = self._amqp_transport.create_source( + self._source, + self._offset, + event_position_selector(self._offset, self._offset_inclusive) ) - self._handler = ReceiveClientAsync( - source, + desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None + + self._handler = self._amqp_transport.create_receive_client( + config=self._client._config, # pylint:disable=protected-access + source=source, auth=auth, - debug=self._client._config.network_tracing, # pylint:disable=protected-access - prefetch=self._prefetch, + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access + link_credit=self._prefetch, link_properties=self._link_properties, - timeout=self._timeout, idle_timeout=self._idle_timeout, - error_policy=self._retry_policy, + retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, - receive_settle_mode=uamqp.constants.ReceiverSettleMode.ReceiveAndDelete, - auto_complete=False, - properties=properties, + properties=create_properties( + self._client._config.user_agent, amqp_transport=self._amqp_transport # pylint:disable=protected-access + ), desired_capabilities=desired_capabilities, - **self._internal_kwargs - ) - - self._handler._streaming_receive = True # pylint:disable=protected-access - self._handler._message_received_callback = ( # pylint:disable=protected-access - self._message_received + streaming_receive=True, + message_received_callback=self._message_received, ) async def _open_with_retry(self) -> None: await self._do_retryable_operation(self._open, operation_need_param=False) def _message_received(self, message: uamqp.Message) -> None: - self._message_buffer.appendleft(message) + self._message_buffer.append(message) def _next_message_in_buffer(self): # pylint:disable=protected-access - message = self._message_buffer.pop() + message = self._message_buffer.popleft() event_data = EventData._from_message(message) self._last_received_event = event_data return event_data diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py index 20c3468bd8b1..0cf4634a655b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import asyncio import logging import datetime @@ -21,7 +22,7 @@ from ._eventprocessor.event_processor import EventProcessor from ._consumer_async import EventHubConsumer from ._client_base_async import ClientBaseAsync -from .._constants import ALL_PARTITIONS +from .._constants import ALL_PARTITIONS, TransportType from .._eventprocessor.common import LoadBalancingStrategy @@ -215,6 +216,7 @@ def _create_consumer( prefetch=prefetch, idle_timeout=self._idle_timeout, track_last_enqueued_event_properties=track_last_enqueued_event_properties, + amqp_transport=self._amqp_transport, **self._internal_kwargs, ) return handler @@ -231,7 +233,7 @@ def from_connection_string( auth_timeout: float = 60, user_agent: Optional[str] = None, retry_total: int = 3, - transport_type: Optional["TransportType"] = None, + transport_type: Optional["TransportType"] = TransportType.Amqp, checkpoint_store: Optional["CheckpointStore"] = None, load_balancing_interval: float = 10, **kwargs: Any diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_error_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_error_async.py deleted file mode 100644 index e272f496ec81..000000000000 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_error_async.py +++ /dev/null @@ -1,74 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- -import asyncio -import logging -from typing import TYPE_CHECKING, Union, cast - -from uamqp import errors - -from ..exceptions import ( - _create_eventhub_exception, - EventHubError, - EventDataSendError, - EventDataError, -) - -if TYPE_CHECKING: - from ._client_base_async import ClientBaseAsync, ConsumerProducerMixin - -_LOGGER = logging.getLogger(__name__) - - -async def _handle_exception( # pylint:disable=too-many-branches, too-many-statements - exception: Exception, closable: Union["ClientBaseAsync", "ConsumerProducerMixin"] -) -> Exception: - # pylint: disable=protected-access - if isinstance(exception, asyncio.CancelledError): - raise exception - error = exception - try: - name = cast("ConsumerProducerMixin", closable)._name - except AttributeError: - name = cast("ClientBaseAsync", closable)._container_id - if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise - _LOGGER.info("%r stops due to keyboard interrupt", name) - await cast("ConsumerProducerMixin", closable)._close_connection_async() - raise error - elif isinstance(exception, EventHubError): - await cast("ConsumerProducerMixin", closable)._close_handler_async() - raise error - elif isinstance( - exception, - ( - errors.MessageAccepted, - errors.MessageAlreadySettled, - errors.MessageModified, - errors.MessageRejected, - errors.MessageReleased, - errors.MessageContentTooLarge, - ), - ): - _LOGGER.info("%r Event data error (%r)", name, exception) - error = EventDataError(str(exception), exception) - raise error - elif isinstance(exception, errors.MessageException): - _LOGGER.info("%r Event data send error (%r)", name, exception) - error = EventDataSendError(str(exception), exception) - raise error - else: - try: - if isinstance(exception, errors.AuthenticationException): - await closable._close_connection_async() - elif isinstance(exception, errors.LinkDetach): - await cast("ConsumerProducerMixin", closable)._close_handler_async() - elif isinstance(exception, errors.ConnectionClose): - await closable._close_connection_async() - elif isinstance(exception, errors.MessageHandlerError): - await cast("ConsumerProducerMixin", closable)._close_handler_async() - else: # errors.AMQPConnectionError, compat.TimeoutException, and any other errors - await closable._close_connection_async() - except AttributeError: - pass - return _create_eventhub_exception(exception) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py index 47db456b0a5f..04d4865ce92f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py @@ -2,23 +2,20 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import uuid import asyncio import logging from typing import Iterable, Union, Optional, Any, AnyStr, List, TYPE_CHECKING import time -from uamqp import types, constants, errors -from uamqp import SendClientAsync - from azure.core.tracing import AbstractSpan from .._common import EventData, EventDataBatch -from ..exceptions import _error_handler, OperationTimeoutError +from ..exceptions import OperationTimeoutError from .._producer import _set_partition_key, _set_trace_message from .._utils import ( create_properties, - set_message_partition_key, trace_message, send_context_manager, transform_outbound_single_message, @@ -29,6 +26,9 @@ from ._async_utils import get_dict_with_loop_if_needed if TYPE_CHECKING: + from uamqp import types, constants, errors + from uamqp import SendClientAsync + from uamqp.authentication import JWTTokenAsync # pylint: disable=ungrouped-imports from ._producer_client_async import EventHubProducerClient @@ -60,8 +60,9 @@ class EventHubProducer( Default value is `True`. """ - def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> None: + def __init__(self, client: EventHubProducerClient, target: str, **kwargs) -> None: super().__init__() + self._amqp_transport = kwargs.pop("amqp_transport") partition = kwargs.get("partition", None) send_timeout = kwargs.get("send_timeout", 60) keep_alive = kwargs.get("keep_alive", None) @@ -79,10 +80,14 @@ def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> N self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect self._timeout = send_timeout - self._idle_timeout = (idle_timeout * 1000) if idle_timeout else None - self._retry_policy = errors.ErrorPolicy( - max_retries=self._client._config.max_retries, - on_error=_error_handler, # pylint:disable=protected-access + self._idle_timeout = ( + (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) + if idle_timeout + else None + ) + + self._retry_policy = self._amqp_transport.create_retry_policy( + config=self._client._config ) self._reconnect_backoff = 1 self._name = "EHProducer-{}".format(uuid.uuid4()) @@ -91,29 +96,31 @@ def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> N if partition: self._target += "/Partitions/" + partition self._name += "-partition{}".format(partition) - self._handler = None # type: Optional[SendClientAsync] - self._outcome = None # type: Optional[constants.MessageSendResult] - self._condition = None # type: Optional[Exception] + self._handler: Optional[SendClientAsync] = None + self._outcome: Optional[constants.MessageSendResult] = None + self._condition: Optional[Exception] = None self._lock = asyncio.Lock(**self._internal_kwargs) - self._link_properties = { - types.AMQPSymbol(TIMEOUT_SYMBOL): types.AMQPLong(int(self._timeout * 1000)) - } + self._link_properties = self._amqp_transport.create_link_properties( + {TIMEOUT_SYMBOL: int(self._timeout * 1000)} + ) + def _create_handler(self, auth: "JWTTokenAsync") -> None: - self._handler = SendClientAsync( - self._target, + self._handler = self._amqp_transport.create_send_client( + config=self._client._config, # pylint:disable=protected-access + target=self._target, auth=auth, - debug=self._client._config.network_tracing, # pylint:disable=protected-access - msg_timeout=self._timeout * 1000, + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access idle_timeout=self._idle_timeout, - error_policy=self._retry_policy, + retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, link_properties=self._link_properties, properties=create_properties( - self._client._config.user_agent # pylint:disable=protected-access + self._client._config.user_agent, # pylint: disable=protected-access + amqp_transport=self._amqp_transport, ), - **self._internal_kwargs + msg_timeout=self._timeout * 1000, ) async def _open_with_retry(self) -> Any: @@ -121,38 +128,15 @@ async def _open_with_retry(self) -> Any: self._open, operation_need_param=False ) - def _set_msg_timeout( - self, timeout_time: Optional[float], last_exception: Optional[Exception] - ) -> None: - if not timeout_time: - return - remaining_time = timeout_time - time.time() - if remaining_time <= 0.0: - if last_exception: - error = last_exception - else: - error = OperationTimeoutError("Send operation timed out") - _LOGGER.info("%r send operation timed out. (%r)", self._name, error) - raise error - self._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access - async def _send_event_data( self, timeout_time: Optional[float] = None, last_exception: Optional[Exception] = None, ) -> None: - # TODO: Correct uAMQP type hints if self._unsent_events: - await self._open() - self._set_msg_timeout(timeout_time, last_exception) - self._handler.queue_message(*self._unsent_events) # type: ignore - await self._handler.wait_async() # type: ignore - self._unsent_events = self._handler.pending_messages # type: ignore - if self._outcome != constants.MessageSendResult.Ok: - if self._outcome == constants.MessageSendResult.Timeout: - self._condition = OperationTimeoutError("Send operation timed out") - if self._condition: - raise self._condition + self._amqp_transport.send_messages( + self, timeout_time, last_exception, _LOGGER + ) async def _send_event_data_with_retry( self, timeout: Optional[float] = None @@ -183,16 +167,20 @@ def _wrap_eventdata( ) -> Union[EventData, EventDataBatch]: if isinstance(event_data, (EventData, AmqpAnnotatedMessage)): outgoing_event_data = transform_outbound_single_message( - event_data, EventData + event_data, EventData, self._amqp_transport.to_outgoing_amqp_message ) if partition_key: - set_message_partition_key(outgoing_event_data.message, partition_key) + self._amqp_transport.set_message_partition_key( + outgoing_event_data._message, partition_key # pylint: disable=protected-access + ) wrapper_event_data = outgoing_event_data trace_message(wrapper_event_data, span) else: if isinstance( event_data, EventDataBatch ): # The partition_key in the param will be omitted. + if not event_data: + return event_data if ( partition_key and partition_key @@ -203,15 +191,16 @@ def _wrap_eventdata( ) for ( event - ) in event_data.message._body_gen: # pylint: disable=protected-access + ) in event_data._message.data: # pylint: disable=protected-access trace_message(event, span) wrapper_event_data = event_data # type:ignore else: if partition_key: - event_data = _set_partition_key(event_data, partition_key) + event_data = _set_partition_key( + event_data, partition_key, self._amqp_transport + ) event_data = _set_trace_message(event_data, span) - wrapper_event_data = EventDataBatch._from_batch(event_data, partition_key) # type: ignore # pylint: disable=protected-access - wrapper_event_data.message.on_send_complete = self._on_outcome + wrapper_event_data = EventDataBatch._from_batch(event_data, self._amqp_transport, partition_key) # type: ignore # pylint: disable=protected-access return wrapper_event_data async def send( @@ -253,7 +242,11 @@ async def send( wrapper_event_data = self._wrap_eventdata( event_data, child, partition_key ) - self._unsent_events = [wrapper_event_data.message] + + if not wrapper_event_data: + return + + self._unsent_events = [wrapper_event_data._message] # pylint: disable=protected-access if child: self._client._add_span_request_attributes( # pylint: disable=protected-access diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py index a425bea6d059..15e7c345f8be 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py @@ -16,7 +16,7 @@ from ._producer_async import EventHubProducer from ._buffered_producer import BufferedProducerDispatcher from .._utils import set_event_partition_key -from .._constants import ALL_PARTITIONS +from .._constants import ALL_PARTITIONS, TransportType from .._common import EventDataBatch, EventData if TYPE_CHECKING: @@ -232,6 +232,7 @@ async def _buffered_send(self, events, **kwargs): self._max_message_size_on_link, max_wait_time=self._max_wait_time, max_buffer_length=self._max_buffer_length, + amqp_transport=self._amqp_transport, ) await self._buffered_producer_dispatcher.enqueue_events(events, **kwargs) @@ -301,9 +302,11 @@ async def _get_max_message_size(self) -> None: EventHubProducer, self._producers[ALL_PARTITIONS] )._open_with_retry() self._max_message_size_on_link = ( - cast( # type: ignore + self._amqp_transport.get_remote_max_message_size( + cast( # type: ignore EventHubProducer, self._producers[ALL_PARTITIONS] - )._handler.message_handler._link.peer_max_message_size + )._handler + ) or constants.MAX_MESSAGE_LENGTH_BYTES ) @@ -350,6 +353,7 @@ def _create_producer( partition=partition_id, send_timeout=send_timeout, idle_timeout=self._idle_timeout, + amqp_transport = self._amqp_transport, **self._internal_kwargs ) return handler @@ -402,7 +406,7 @@ def from_connection_string( auth_timeout: float = 60, user_agent: Optional[str] = None, retry_total: int = 3, - transport_type: Optional["TransportType"] = None, + transport_type: Optional["TransportType"] = TransportType.Amqp, **kwargs: Any ) -> "EventHubProducerClient": """Create an EventHubProducerClient from a connection string. @@ -719,6 +723,7 @@ async def create_batch( max_size_in_bytes=(max_size_in_bytes or self._max_message_size_on_link), partition_id=partition_id, partition_key=partition_key, + amqp_transport=self._amqp_transport, ) return event_data_batch diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py new file mode 100644 index 000000000000..ea36b4288da0 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py @@ -0,0 +1,232 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from abc import ABC, abstractmethod + +class AmqpTransportAsync(ABC): + """ + Abstract class that defines a set of common methods needed by producer and consumer. + """ + # define constants + BATCH_MESSAGE = None + MAX_FRAME_SIZE_BYTES = None + IDLE_TIMEOUT_FACTOR = None + MESSAGE = None + + # define symbols + PRODUCT_SYMBOL = None + VERSION_SYMBOL = None + FRAMEWORK_SYMBOL = None + PLATFORM_SYMBOL = None + USER_AGENT_SYMBOL = None + PROP_PARTITION_KEY_AMQP_SYMBOL = None + + # errors + AMQP_LINK_ERROR = None + LINK_STOLEN_CONDITION = None + MGMT_AUTH_EXCEPTION = None + CONNECTION_ERROR = None + AMQP_CONNECTION_ERROR = None + + @staticmethod + @abstractmethod + def to_outgoing_amqp_message(annotated_message): + """ + Converts an AmqpAnnotatedMessage into an Amqp Message. + :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert. + :rtype: uamqp.Message or pyamqp.Message + """ + + @staticmethod + async def get_batch_message_encoded_size(message): + """ + Gets the batch message encoded size given an underlying Message. + :param uamqp.BatchMessage message: Message to get encoded size of. + :rtype: int + """ + return await message.gather()[0].get_message_encoded_size() + + @staticmethod + @abstractmethod + async def get_message_encoded_size(message): + """ + Gets the message encoded size given an underlying Message. + :param uamqp.Message or pyamqp.Message message: Message to get encoded size of. + :rtype: int + """ + + @staticmethod + @abstractmethod + async def get_remote_max_message_size(handler): + """ + Returns max peer message size. + :param AMQPClient handler: Client to get remote max message size on link from. + :rtype: int + """ + + @staticmethod + @abstractmethod + async def create_retry_policy(config): + """ + Creates the error retry policy. + :param ~azure.eventhub._configuration.Configuration config: Configuration. + """ + + @staticmethod + @abstractmethod + async def create_link_properties(link_properties): + """ + Creates and returns the link properties. + :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. + :rtype: dict + """ + + @staticmethod + @abstractmethod + async def create_send_client(*, config, **kwargs): + """ + Creates and returns the send client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + + @staticmethod + @abstractmethod + async def send_messages(producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + + @staticmethod + @abstractmethod + async def set_message_partition_key(message, partition_key, **kwargs): + """Set the partition key as an annotation on a uamqp message. + + :param message: The message to update. + :param str partition_key: The partition key value. + :rtype: None + """ + + @staticmethod + @abstractmethod + async def add_batch(batch_message, outgoing_event_data, event_data): + """ + Add EventData to the data body of the BatchMessage. + :param batch_message: BatchMessage to add data to. + :param outgoing_event_data: Transformed EventData for sending. + :param event_data: EventData to add to internal batch events. uamqp use only. + :rtype: None + """ + + @staticmethod + @abstractmethod + async def create_source(source, offset, selector): + """ + Creates and returns the Source. + + :param str source: Required. + :param int offset: Required. + :param bytes selector: Required. + """ + + @staticmethod + @abstractmethod + async def create_receive_client(*, config, **kwargs): + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword Source source: Required. The source. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. Missing in pyamqp. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + @staticmethod + @abstractmethod + async def open_receive_client(*, handler, client, auth): + """ + Opens the receive client. + :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + """ + + @staticmethod + @abstractmethod + async def create_token_auth(auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Whether to update token. If not updating token, + then pass 300 to refresh_window. Only used by uamqp. + """ + + @staticmethod + @abstractmethod + async def create_mgmt_client(address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + @staticmethod + @abstractmethod + async def get_updated_token(mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + + @staticmethod + @abstractmethod + async def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + + @staticmethod + @abstractmethod + async def get_error(error, message, *, condition=None): + """ + Gets error and passes in error message, and, if applicable, condition. + :param error: The error to raise. + :param str message: Error message. + :param condition: Optional error condition. Will not be used by uamqp. + """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py new file mode 100644 index 000000000000..1c5cdcb4797c --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py @@ -0,0 +1,356 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from __future__ import annotations +import asyncio +import logging +from typing import Optional, Union, Any, cast, TYPE_CHECKING + +try: + from uamqp import ( + BatchMessage, + constants, + MessageBodyType, + Message, + types, + SendClientAsync, + ReceiveClientAsync, + Source, + utils, + authentication, + AMQPClientAsync, + compat, + errors, + ) + from uamqp.message import ( + MessageHeader, + MessageProperties, + ) + uamqp_installed = True +except ImportError: + uamqp_installed = False + +from ._base_async import AmqpTransportAsync +from ...amqp._constants import AmqpMessageBodyType +from ..._constants import ( + NO_RETRY_ERRORS, + PROP_PARTITION_KEY, +) + +from ...exceptions import ( + ConnectError, + EventDataError, + EventDataSendError, + OperationTimeoutError, + EventHubError, + AuthenticationError, + ConnectionLostError, + EventDataError, + EventDataSendError, +) + +if TYPE_CHECKING: + from .._client_base_async import ClientBaseAsync, ConsumerProducerMixin + +_LOGGER = logging.getLogger(__name__) + +if uamqp_installed: + + from ..._transport._uamqp_transport import UamqpTransport + + class UamqpTransportAsync(UamqpTransport, AmqpTransportAsync): + """ + Class which defines uamqp-based methods used by the producer and consumer. + """ + + @staticmethod + async def get_batch_message_encoded_size(message): + """ + Gets the batch message encoded size given an underlying Message. + :param uamqp.BatchMessage message: Message to get encoded size of. + :rtype: int + """ + return await message.gather()[0].get_message_encoded_size() + + @staticmethod + def create_send_client(*, config, **kwargs): # pylint:disable=unused-argument + """ + Creates and returns the uamqp SendClient. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + target = kwargs.pop("target") + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + + return SendClientAsync( + target, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + **kwargs + ) + + @staticmethod + async def send_messages(producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + # pylint: disable=protected-access + await producer._open() + producer._unsent_events[0].on_send_complete = producer._on_outcome + UamqpTransportAsync._set_msg_timeout(producer, timeout_time, last_exception, logger) + producer._handler.queue_message(*producer._unsent_events) # type: ignore + await producer._handler.wait_async() # type: ignore + producer._unsent_events = producer._handler.pending_messages # type: ignore + if producer._outcome != constants.MessageSendResult.Ok: + if producer._outcome == constants.MessageSendResult.Timeout: + producer._condition = OperationTimeoutError("Send operation timed out") + if producer._condition: + raise producer._condition + + @staticmethod + def set_message_partition_key(message, partition_key, **kwargs): # pylint:disable=unused-argument + # type: (Message, Optional[Union[bytes, str]], Any) -> Message + """Set the partition key as an annotation on a uamqp message. + + :param ~uamqp.Message message: The message to update. + :param str partition_key: The partition key value. + :rtype: Message + """ + if partition_key: + annotations = message.annotations + if annotations is None: + annotations = {} + annotations[ + UamqpTransport.PROP_PARTITION_KEY_AMQP_SYMBOL # TODO: see if setting non-amqp symbol is valid + ] = partition_key + header = MessageHeader() + header.durable = True + message.annotations = annotations + message.header = header + return message + + @staticmethod + async def add_batch(batch_message, outgoing_event_data, event_data): + """ + Add EventData to the data body of the BatchMessage. + :param batch_message: BatchMessage to add data to. + :param outgoing_event_data: Transformed EventData for sending. + :param event_data: EventData to add to internal batch events. uamqp use only. + :rtype: None + """ + # pylint: disable=protected-access + batch_message._internal_events.append(event_data) + batch_message._message._body_gen.append( + outgoing_event_data._message + ) + + @staticmethod + async def create_receive_client(*, config, **kwargs): + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str source: Required. The source. + :keyword str offset: Required. + :keyword str offset_inclusive: Required. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + source = kwargs.pop("source") + symbol_array = kwargs.pop("desired_capabilities") + desired_capabilities = None + if symbol_array: + symbol_array = [types.AMQPSymbol(symbol) for symbol in symbol_array] + desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + link_credit = kwargs.pop("link_credit") + streaming_receive = kwargs.pop("streaming_receive") + message_received_callback = kwargs.pop("message_received_callback") + + client = ReceiveClientAsync( + source, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + desired_capabilities=desired_capabilities, + prefetch=link_credit, + receive_settle_mode=constants.ReceiverSettleMode.ReceiveAndDelete, + auto_complete=False, + **kwargs + ) + # pylint:disable=protected-access + client._streaming_receive = streaming_receive + client._message_received_callback = (message_received_callback) + return client + + @staticmethod + async def open_receive_client(*, handler, client, auth): + """ + Opens the receive client and returns ready status. + :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + :param auth: Auth. + :rtype: bool + """ + # pylint:disable=protected-access + await handler.open() + + @staticmethod + async def create_token_auth(auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Required. Whether to update token. If not updating token, + then pass 300 to refresh_window. + """ + update_token = kwargs.pop("update_token") + refresh_window = 300 + if update_token: + refresh_window = 0 + + token_auth = authentication.JWTTokenAsync( + auth_uri, + auth_uri, + get_token, + token_type=token_type, + timeout=config.auth_timeout, + http_proxy=config.http_proxy, + transport_type=config.transport_type, + custom_endpoint_hostname=config.custom_endpoint_hostname, + port=config.connection_port, + verify=config.connection_verify, + refresh_window=refresh_window + ) + if update_token: + await token_auth.update_token() + return token_auth + + @staticmethod + def create_mgmt_client(address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + mgmt_target = f"amqps://{address.hostname}{address.path}" + return AMQPClientAsync( + mgmt_target, + auth=mgmt_auth, + debug=config.network_tracing + ) + + @staticmethod + async def get_updated_token(mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + return mgmt_auth.token + + @staticmethod + async def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + operation_type = kwargs.pop("operation_type") + operation = kwargs.pop("operation") + return await mgmt_client.mgmt_request_async( + mgmt_msg, + operation, + op_type=operation_type, + **kwargs + ) + + @staticmethod + async def _handle_exception( # pylint:disable=too-many-branches, too-many-statements + exception: Exception, closable: Union["ClientBaseAsync", "ConsumerProducerMixin"] + ) -> Exception: + # pylint: disable=protected-access + if isinstance(exception, asyncio.CancelledError): + raise exception + error = exception + try: + name = cast("ConsumerProducerMixin", closable)._name + except AttributeError: + name = cast("ClientBaseAsync", closable)._container_id + if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise + _LOGGER.info("%r stops due to keyboard interrupt", name) + await cast("ConsumerProducerMixin", closable)._close_connection_async() + raise error + elif isinstance(exception, EventHubError): + await cast("ConsumerProducerMixin", closable)._close_handler_async() + raise error + elif isinstance( + exception, + ( + errors.MessageAccepted, + errors.MessageAlreadySettled, + errors.MessageModified, + errors.MessageRejected, + errors.MessageReleased, + errors.MessageContentTooLarge, + ), + ): + _LOGGER.info("%r Event data error (%r)", name, exception) + error = EventDataError(str(exception), exception) + raise error + elif isinstance(exception, errors.MessageException): + _LOGGER.info("%r Event data send error (%r)", name, exception) + error = EventDataSendError(str(exception), exception) + raise error + else: + try: + if isinstance(exception, errors.AuthenticationException): + await closable._close_connection_async() + elif isinstance(exception, errors.LinkDetach): + await cast("ConsumerProducerMixin", closable)._close_handler_async() + elif isinstance(exception, errors.ConnectionClose): + await closable._close_connection_async() + elif isinstance(exception, errors.MessageHandlerError): + await cast("ConsumerProducerMixin", closable)._close_handler_async() + else: # errors.AMQPConnectionError, compat.TimeoutException, and any other errors + await closable._close_connection_async() + except AttributeError: + pass + return UamqpTransport._create_eventhub_exception(exception) \ No newline at end of file diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py index 8001d97cea6d..a688e70a8861 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py @@ -98,28 +98,3 @@ class OperationTimeoutError(EventHubError): class OwnershipLostError(Exception): """Raised when `update_checkpoint` detects the ownership to a partition has been lost.""" - -# TODO: delete when async unittests have been refactored -def _create_eventhub_exception(exception): - if isinstance(exception, errors.AuthenticationException): - error = AuthenticationError(str(exception), exception) - elif isinstance(exception, errors.VendorLinkDetach): - error = ConnectError(str(exception), exception) - elif isinstance(exception, errors.LinkDetach): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.ConnectionClose): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.MessageHandlerError): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.AMQPConnectionError): - error_type = ( - AuthenticationError - if str(exception).startswith("Unable to open authentication session") - else ConnectError - ) - error = error_type(str(exception), exception) - elif isinstance(exception, compat.TimeoutException): - error = ConnectionLostError(str(exception), exception) - else: - error = EventHubError(str(exception), exception) - return error From 0a6d1b871b86d087be6b2436ea3b36fa58272788 Mon Sep 17 00:00:00 2001 From: swathipil Date: Mon, 1 Aug 2022 20:09:08 -0700 Subject: [PATCH 07/20] fix bugs --- .../azure-eventhub/azure/eventhub/_common.py | 3 +- .../azure/eventhub/aio/_producer_async.py | 2 +- .../eventhub/aio/_transport/_base_async.py | 49 ++++--------------- .../aio/_transport/_uamqp_transport_async.py | 9 ---- 4 files changed, 12 insertions(+), 51 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 26f5d42b53cf..a616197aa870 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -511,9 +511,8 @@ def __init__( **kwargs, ) -> None: # TODO: this changes API, check with Anna if valid - - # Need move out message creation to right before sending. + # Might need move out message creation to right before sending. # Might take more time to loop through events and add them all to batch in `send` than in `add` here - # Default async vs sync might cause issues. self._amqp_transport = kwargs.pop("amqp_transport", UamqpTransport) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py index 04d4865ce92f..7c4cd4be1d62 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py @@ -134,7 +134,7 @@ async def _send_event_data( last_exception: Optional[Exception] = None, ) -> None: if self._unsent_events: - self._amqp_transport.send_messages( + await self._amqp_transport.send_messages( self, timeout_time, last_exception, _LOGGER ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py index ea36b4288da0..e8a60cf4394f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py @@ -39,26 +39,16 @@ def to_outgoing_amqp_message(annotated_message): """ @staticmethod - async def get_batch_message_encoded_size(message): + def get_batch_message_encoded_size(message): """ Gets the batch message encoded size given an underlying Message. :param uamqp.BatchMessage message: Message to get encoded size of. :rtype: int """ - return await message.gather()[0].get_message_encoded_size() @staticmethod @abstractmethod - async def get_message_encoded_size(message): - """ - Gets the message encoded size given an underlying Message. - :param uamqp.Message or pyamqp.Message message: Message to get encoded size of. - :rtype: int - """ - - @staticmethod - @abstractmethod - async def get_remote_max_message_size(handler): + def get_remote_max_message_size(handler): """ Returns max peer message size. :param AMQPClient handler: Client to get remote max message size on link from. @@ -67,7 +57,7 @@ async def get_remote_max_message_size(handler): @staticmethod @abstractmethod - async def create_retry_policy(config): + def create_retry_policy(config): """ Creates the error retry policy. :param ~azure.eventhub._configuration.Configuration config: Configuration. @@ -75,7 +65,7 @@ async def create_retry_policy(config): @staticmethod @abstractmethod - async def create_link_properties(link_properties): + def create_link_properties(link_properties): """ Creates and returns the link properties. :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. @@ -84,7 +74,7 @@ async def create_link_properties(link_properties): @staticmethod @abstractmethod - async def create_send_client(*, config, **kwargs): + def create_send_client(*, config, **kwargs): """ Creates and returns the send client. :param ~azure.eventhub._configuration.Configuration config: The configuration. @@ -113,7 +103,7 @@ async def send_messages(producer, timeout_time, last_exception, logger): @staticmethod @abstractmethod - async def set_message_partition_key(message, partition_key, **kwargs): + def set_message_partition_key(message, partition_key, **kwargs): """Set the partition key as an annotation on a uamqp message. :param message: The message to update. @@ -123,18 +113,7 @@ async def set_message_partition_key(message, partition_key, **kwargs): @staticmethod @abstractmethod - async def add_batch(batch_message, outgoing_event_data, event_data): - """ - Add EventData to the data body of the BatchMessage. - :param batch_message: BatchMessage to add data to. - :param outgoing_event_data: Transformed EventData for sending. - :param event_data: EventData to add to internal batch events. uamqp use only. - :rtype: None - """ - - @staticmethod - @abstractmethod - async def create_source(source, offset, selector): + def create_source(source, offset, selector): """ Creates and returns the Source. @@ -145,7 +124,7 @@ async def create_source(source, offset, selector): @staticmethod @abstractmethod - async def create_receive_client(*, config, **kwargs): + def create_receive_client(*, config, **kwargs): """ Creates and returns the receive client. :param ~azure.eventhub._configuration.Configuration config: The configuration. @@ -192,7 +171,7 @@ async def create_token_auth(auth_uri, get_token, token_type, config, **kwargs): @staticmethod @abstractmethod - async def create_mgmt_client(address, mgmt_auth, config): + def create_mgmt_client(address, mgmt_auth, config): """ Creates and returns the mgmt AMQP client. :param _Address address: Required. The Address. @@ -200,14 +179,6 @@ async def create_mgmt_client(address, mgmt_auth, config): :param ~azure.eventhub._configuration.Configuration config: The configuration. """ - @staticmethod - @abstractmethod - async def get_updated_token(mgmt_auth): - """ - Return updated auth token. - :param mgmt_auth: Auth. - """ - @staticmethod @abstractmethod async def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): @@ -223,7 +194,7 @@ async def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): @staticmethod @abstractmethod - async def get_error(error, message, *, condition=None): + def get_error(error, message, *, condition=None): """ Gets error and passes in error message, and, if applicable, condition. :param error: The error to raise. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py index 1c5cdcb4797c..208226eeb67e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py @@ -65,15 +65,6 @@ class UamqpTransportAsync(UamqpTransport, AmqpTransportAsync): Class which defines uamqp-based methods used by the producer and consumer. """ - @staticmethod - async def get_batch_message_encoded_size(message): - """ - Gets the batch message encoded size given an underlying Message. - :param uamqp.BatchMessage message: Message to get encoded size of. - :rtype: int - """ - return await message.gather()[0].get_message_encoded_size() - @staticmethod def create_send_client(*, config, **kwargs): # pylint:disable=unused-argument """ From 4fbe9dbd33832b1d864fd14392fb0c2466cf5968 Mon Sep 17 00:00:00 2001 From: swathipil Date: Tue, 2 Aug 2022 14:09:22 -0700 Subject: [PATCH 08/20] update consumer code to fix tests --- .../azure/eventhub/aio/_consumer_async.py | 55 +-------- .../eventhub/aio/_transport/_base_async.py | 17 ++- .../aio/_transport/_uamqp_transport_async.py | 107 ++++++++++-------- .../tests/livetest/asynctests/__init__.py | 4 + .../livetest/asynctests/test_send_async.py | 11 +- 5 files changed, 89 insertions(+), 105 deletions(-) create mode 100644 sdk/eventhub/azure-eventhub/tests/livetest/asynctests/__init__.py diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py index f3bcb0a36636..c6f375184b09 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py @@ -156,57 +156,4 @@ def _next_message_in_buffer(self): async def receive( self, batch=False, max_batch_size=300, max_wait_time=None ) -> None: - max_retries = ( - self._client._config.max_retries # pylint:disable=protected-access - ) - has_not_fetched_once = True # ensure one trip when max_wait_time is very small - deadline = time.time() + (max_wait_time or 0) # max_wait_time can be None - while len(self._message_buffer) < max_batch_size and ( - time.time() < deadline or has_not_fetched_once - ): - retried_times = 0 - has_not_fetched_once = False - while retried_times <= max_retries: - try: - await self._open() - await cast( - ReceiveClientAsync, self._handler - ).do_work_async() # uamqp sleeps 0.05 if none received - break - except asyncio.CancelledError: # pylint: disable=try-except-raise - raise - except Exception as exception: # pylint: disable=broad-except - if ( - isinstance(exception, uamqp.errors.LinkDetach) - and exception.condition # pylint: disable=no-member - == uamqp.constants.ErrorCodes.LinkStolen - ): - raise await self._handle_exception(exception) - if not self.running: # exit by close - return - if self._last_received_event: - self._offset = self._last_received_event.offset - last_exception = await self._handle_exception(exception) - retried_times += 1 - if retried_times > max_retries: - _LOGGER.info( - "%r operation has exhausted retry. Last exception: %r.", - self._name, - last_exception, - ) - raise last_exception - - if self._message_buffer: - while self._message_buffer: - if batch: - events_for_callback = [] # type: List[EventData] - for _ in range(min(max_batch_size, len(self._message_buffer))): - events_for_callback.append(self._next_message_in_buffer()) - await self._on_event_received(events_for_callback) - else: - await self._on_event_received(self._next_message_in_buffer()) - elif max_wait_time: - if batch: - await self._on_event_received([]) - else: - await self._on_event_received(None) + await self._amqp_transport.receive_messages(self, batch, max_batch_size, max_wait_time) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py index e8a60cf4394f..6c1d7837e3f6 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py @@ -39,6 +39,7 @@ def to_outgoing_amqp_message(annotated_message): """ @staticmethod + @abstractmethod def get_batch_message_encoded_size(message): """ Gets the batch message encoded size given an underlying Message. @@ -147,11 +148,13 @@ def create_receive_client(*, config, **kwargs): @staticmethod @abstractmethod - async def open_receive_client(*, handler, client, auth): + async def receive_messages(handler, batch, max_batch_size, max_wait_time): """ - Opens the receive client. + Receives messages, creates events, and returns them by calling the on received callback. :param ReceiveClient handler: The receive client. - :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + :param bool batch: If receive batch or single event. + :param int max_batch_size: Max batch size. + :param int or None max_wait_time: Max wait time. """ @staticmethod @@ -179,6 +182,14 @@ def create_mgmt_client(address, mgmt_auth, config): :param ~azure.eventhub._configuration.Configuration config: The configuration. """ + @staticmethod + @abstractmethod + async def get_updated_token(mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + @staticmethod @abstractmethod async def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py index 208226eeb67e..7a57a44ef361 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py @@ -5,8 +5,9 @@ from __future__ import annotations import asyncio +import time import logging -from typing import Optional, Union, Any, cast, TYPE_CHECKING +from typing import Union, cast, TYPE_CHECKING, List try: from uamqp import ( @@ -53,6 +54,7 @@ if TYPE_CHECKING: from .._client_base_async import ClientBaseAsync, ConsumerProducerMixin + from ..._common import EventData _LOGGER = logging.getLogger(__name__) @@ -115,44 +117,7 @@ async def send_messages(producer, timeout_time, last_exception, logger): raise producer._condition @staticmethod - def set_message_partition_key(message, partition_key, **kwargs): # pylint:disable=unused-argument - # type: (Message, Optional[Union[bytes, str]], Any) -> Message - """Set the partition key as an annotation on a uamqp message. - - :param ~uamqp.Message message: The message to update. - :param str partition_key: The partition key value. - :rtype: Message - """ - if partition_key: - annotations = message.annotations - if annotations is None: - annotations = {} - annotations[ - UamqpTransport.PROP_PARTITION_KEY_AMQP_SYMBOL # TODO: see if setting non-amqp symbol is valid - ] = partition_key - header = MessageHeader() - header.durable = True - message.annotations = annotations - message.header = header - return message - - @staticmethod - async def add_batch(batch_message, outgoing_event_data, event_data): - """ - Add EventData to the data body of the BatchMessage. - :param batch_message: BatchMessage to add data to. - :param outgoing_event_data: Transformed EventData for sending. - :param event_data: EventData to add to internal batch events. uamqp use only. - :rtype: None - """ - # pylint: disable=protected-access - batch_message._internal_events.append(event_data) - batch_message._message._body_gen.append( - outgoing_event_data._message - ) - - @staticmethod - async def create_receive_client(*, config, **kwargs): + def create_receive_client(*, config, **kwargs): """ Creates and returns the receive client. :param ~azure.eventhub._configuration.Configuration config: The configuration. @@ -203,16 +168,68 @@ async def create_receive_client(*, config, **kwargs): return client @staticmethod - async def open_receive_client(*, handler, client, auth): + async def receive_messages(handler, batch, max_batch_size, max_wait_time): """ - Opens the receive client and returns ready status. + Receives messages, creates events, and returns them by calling the on received callback. :param ReceiveClient handler: The receive client. - :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. - :param auth: Auth. - :rtype: bool + :param bool batch: If receive batch or single event. + :param int max_batch_size: Max batch size. + :param int or None max_wait_time: Max wait time. """ # pylint:disable=protected-access - await handler.open() + max_retries = ( + handler._client._config.max_retries # pylint:disable=protected-access + ) + has_not_fetched_once = True # ensure one trip when max_wait_time is very small + deadline = time.time() + (max_wait_time or 0) # max_wait_time can be None + while len(handler._message_buffer) < max_batch_size and ( + time.time() < deadline or has_not_fetched_once + ): + retried_times = 0 + has_not_fetched_once = False + while retried_times <= max_retries: + try: + await handler._open() + await cast( + ReceiveClientAsync, handler._handler + ).do_work_async() # uamqp sleeps 0.05 if none received + break + except asyncio.CancelledError: # pylint: disable=try-except-raise + raise + except Exception as exception: # pylint: disable=broad-except + if ( + isinstance(exception, UamqpTransportAsync.AMQP_LINK_ERROR) + and exception.condition == UamqpTransportAsync.LINK_STOLEN_CONDITION # pylint: disable=no-member + ): + raise await handler._handle_exception(exception) + if not handler.running: # exit by close + return + if handler._last_received_event: + handler._offset = handler._last_received_event.offset + last_exception = await handler._handle_exception(exception) + retried_times += 1 + if retried_times > max_retries: + _LOGGER.info( + "%r operation has exhausted retry. Last exception: %r.", + handler._name, + last_exception, + ) + raise last_exception + + if handler._message_buffer: + while handler._message_buffer: + if batch: + events_for_callback: List[EventData] = [] + for _ in range(min(max_batch_size, len(handler._message_buffer))): + events_for_callback.append(handler._next_message_in_buffer()) + await handler._on_event_received(events_for_callback) + else: + await handler._on_event_received(handler._next_message_in_buffer()) + elif max_wait_time: + if batch: + await handler._on_event_received([]) + else: + await handler._on_event_received(None) @staticmethod async def create_token_auth(auth_uri, get_token, token_type, config, **kwargs): diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/__init__.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py index 6b9157d30440..2ff6087c001b 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py @@ -21,12 +21,17 @@ AmqpAnnotatedMessage, AmqpMessageProperties, ) +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() + +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) @pytest.mark.liveTest @pytest.mark.asyncio -async def test_send_amqp_annotated_message(connstr_receivers): +async def test_send_amqp_annotated_message(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) async with client: sequence_body = [b'message', 123.456, True] footer = {'footer_key': 'footer_value'} @@ -126,7 +131,7 @@ async def on_event(partition_context, event): on_event.received = [] client = EventHubConsumerClient.from_connection_string(connection_str, - consumer_group='$default') + consumer_group='$default', uamqp_transport=uamqp_transport) async with client: task = asyncio.ensure_future(client.receive(on_event, starting_position="-1")) await asyncio.sleep(15) From 0711b86e9befefdc78416f462545c2abf9a23fca Mon Sep 17 00:00:00 2001 From: swathipil Date: Tue, 2 Aug 2022 16:57:40 -0700 Subject: [PATCH 09/20] update conftest with uamqp_transport fixture --- sdk/eventhub/azure-eventhub/conftest.py | 10 ++++- sdk/eventhub/azure-eventhub/tests/__init__.py | 4 -- .../azure-eventhub/tests/_test_case.py | 7 ---- .../azure-eventhub/tests/livetest/__init__.py | 4 -- .../tests/livetest/asynctests/__init__.py | 4 -- .../livetest/asynctests/test_send_async.py | 4 -- .../tests/livetest/synctests/__init__.py | 4 -- .../tests/livetest/synctests/test_auth.py | 10 ----- .../synctests/test_buffered_producer.py | 15 ------- .../synctests/test_consumer_client.py | 15 +------ .../tests/livetest/synctests/test_negative.py | 11 ----- .../livetest/synctests/test_properties.py | 8 ---- .../tests/livetest/synctests/test_receive.py | 40 +------------------ .../livetest/synctests/test_reconnect.py | 37 +++++++---------- .../tests/livetest/synctests/test_send.py | 33 ++------------- .../azure-eventhub/tests/unittest/__init__.py | 4 -- .../tests/unittest/test_event_data.py | 11 +---- 17 files changed, 30 insertions(+), 191 deletions(-) delete mode 100644 sdk/eventhub/azure-eventhub/tests/__init__.py delete mode 100644 sdk/eventhub/azure-eventhub/tests/_test_case.py delete mode 100644 sdk/eventhub/azure-eventhub/tests/livetest/__init__.py delete mode 100644 sdk/eventhub/azure-eventhub/tests/livetest/asynctests/__init__.py delete mode 100644 sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py delete mode 100644 sdk/eventhub/azure-eventhub/tests/unittest/__init__.py diff --git a/sdk/eventhub/azure-eventhub/conftest.py b/sdk/eventhub/azure-eventhub/conftest.py index 79862e7a67eb..981fcf68f53a 100644 --- a/sdk/eventhub/azure-eventhub/conftest.py +++ b/sdk/eventhub/azure-eventhub/conftest.py @@ -42,6 +42,9 @@ def sleep(request): sleep = request.config.getoption("--sleep") return sleep.lower() in ('true', 'yes', '1', 'y') +@pytest.fixture(scope="session", params=[True]) +def uamqp_transport(request): + return request.param def get_logger(filename, level=logging.INFO): azure_logger = logging.getLogger("azure.eventhub") @@ -69,8 +72,11 @@ def get_logger(filename, level=logging.INFO): @pytest.fixture(scope="session") -def timeout_factor(): - return 1000 # TODO: if pyamqp ReceiveClient is used, set to 1 +def timeout_factor(uamqp_transport): + if uamqp_transport: + return 1000 + else: + return 1 @pytest.fixture(scope="session") def resource_group(): diff --git a/sdk/eventhub/azure-eventhub/tests/__init__.py b/sdk/eventhub/azure-eventhub/tests/__init__.py deleted file mode 100644 index 34913fb394d7..000000000000 --- a/sdk/eventhub/azure-eventhub/tests/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/_test_case.py b/sdk/eventhub/azure-eventhub/tests/_test_case.py deleted file mode 100644 index f58d3d758ba1..000000000000 --- a/sdk/eventhub/azure-eventhub/tests/_test_case.py +++ /dev/null @@ -1,7 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ - -def get_decorator(): - return [True] diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/__init__.py b/sdk/eventhub/azure-eventhub/tests/livetest/__init__.py deleted file mode 100644 index 34913fb394d7..000000000000 --- a/sdk/eventhub/azure-eventhub/tests/livetest/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/__init__.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/__init__.py deleted file mode 100644 index 34913fb394d7..000000000000 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py index 2ff6087c001b..255818a2511c 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py @@ -21,12 +21,8 @@ AmqpAnnotatedMessage, AmqpMessageProperties, ) -from ..._test_case import get_decorator -uamqp_transport_vals = get_decorator() -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest @pytest.mark.asyncio async def test_send_amqp_annotated_message(connstr_receivers, uamqp_transport): diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py deleted file mode 100644 index 34913fb394d7..000000000000 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py index e8d9c4dd8341..3fd512bbf141 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py @@ -11,12 +11,8 @@ from azure.eventhub import EventData, EventHubProducerClient, EventHubConsumerClient, EventHubSharedKeyCredential from azure.eventhub._client_base import EventHubSASTokenCredential from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential -from ..._test_case import get_decorator -uamqp_transport_vals = get_decorator() -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_client_secret_credential(live_eventhub, uamqp_transport): credential = EnvironmentCredential() @@ -57,8 +53,6 @@ def on_event(partition_context, event): assert list(on_event.event.body)[0] == 'A single message'.encode('utf-8') -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_client_sas_credential(live_eventhub, uamqp_transport): # This should "just work" to validate known-good. @@ -98,8 +92,6 @@ def test_client_sas_credential(live_eventhub, uamqp_transport): conn_str_producer_client.send_batch(batch) -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_client_azure_sas_credential(live_eventhub, uamqp_transport): # This should "just work" to validate known-good. @@ -127,8 +119,6 @@ def test_client_azure_sas_credential(live_eventhub, uamqp_transport): producer_client.send_batch(batch) -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_client_azure_named_key_credential(live_eventhub, uamqp_transport): credential = AzureNamedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py index be5801eda051..0fd8e9458417 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py @@ -20,9 +20,6 @@ AmqpAnnotatedMessage, ) from azure.eventhub.exceptions import EventDataSendError, OperationTimeoutError, EventHubError -from ..._test_case import get_decorator - -uamqp_transport_vals = get_decorator() def random_pkey_generation(partitions): @@ -42,8 +39,6 @@ def random_pkey_generation(partitions): return dic -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest() def test_producer_client_constructor(connection_str, uamqp_transport): def on_success(events, pid): @@ -78,8 +73,6 @@ def on_error(events, error, pid): @pytest.mark.liveTest -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.parametrize( "flush_after_sending, close_after_sending", [ @@ -187,8 +180,6 @@ def on_error(events, pid, err): @pytest.mark.liveTest -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.parametrize( "flush_after_sending, close_after_sending", [ @@ -305,8 +296,6 @@ def on_error(events, pid, err): @pytest.mark.liveTest -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) def test_send_with_hybrid_partition_assignment(connection_str, uamqp_transport): received_events = defaultdict(list) @@ -397,8 +386,6 @@ def on_error(events, pid, err): receive_thread.join() -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) def test_send_with_timing_configuration(connection_str, uamqp_transport): received_events = defaultdict(list) @@ -477,8 +464,6 @@ def on_error(events, pid, err): @pytest.mark.liveTest -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) def test_long_sleep(connection_str, uamqp_transport): received_events = defaultdict(list) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py index e9320c83aed7..8b7420075ef4 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py @@ -6,13 +6,8 @@ from azure.eventhub import EventHubConsumerClient from azure.eventhub._eventprocessor.in_memory_checkpoint_store import InMemoryCheckpointStore from azure.eventhub._constants import ALL_PARTITIONS -from ..._test_case import get_decorator -uamqp_transport_vals = get_decorator() - -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_receive_no_partition(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders @@ -55,8 +50,6 @@ def on_event(partition_context, event): assert len([checkpoint for checkpoint in checkpoints if checkpoint["sequence_number"] == on_event.sequence_number]) > 0 -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_receive_partition(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders @@ -87,8 +80,6 @@ def on_event(partition_context, event): assert on_event.eventhub_name == senders[0]._client.eventhub_name -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_receive_load_balancing(connstr_senders, uamqp_transport): if sys.platform.startswith('darwin'): @@ -123,8 +114,6 @@ def on_event(partition_context, event): assert len(client2._event_processors[("$default", ALL_PARTITIONS)]._consumers) == 1 -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) def test_receive_batch_no_max_wait_time(connstr_senders, uamqp_transport): '''Test whether callback is called when max_wait_time is None and max_batch_size has reached ''' @@ -168,7 +157,7 @@ def on_event_batch(partition_context, event_batch): worker.join() -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) + @pytest.mark.parametrize("max_wait_time, sleep_time, expected_result", [(3, 10, []), (3, 2, None)]) @@ -190,8 +179,6 @@ def on_event_batch(partition_context, event_batch): worker.join() -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) def test_receive_batch_early_callback(connstr_senders, uamqp_transport): ''' Test whether the callback is called once max_batch_size reaches and before max_wait_time reaches. ''' diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py index d08700873537..aae3ce3aafd7 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py @@ -23,12 +23,8 @@ from azure.eventhub._transport._uamqp_transport import UamqpTransport except (ImportError, ModuleNotFoundError): UamqpTransport = None -from ..._test_case import get_decorator -uamqp_transport_vals = get_decorator() - -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_send_batch_with_invalid_hostname(invalid_hostname, uamqp_transport): amqp_transport = UamqpTransport if uamqp_transport else None @@ -63,7 +59,6 @@ def on_error(events, pid, err): assert isinstance(on_error.err, ConnectError) -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_receive_with_invalid_hostname_sync(invalid_hostname, uamqp_transport): def on_event(partition_context, event): @@ -81,7 +76,6 @@ def on_event(partition_context, event): thread.join() -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_send_batch_with_invalid_key(invalid_key, uamqp_transport): client = EventHubProducerClient.from_connection_string(invalid_key, uamqp_transport=uamqp_transport) @@ -95,7 +89,6 @@ def test_send_batch_with_invalid_key(invalid_key, uamqp_transport): client.close() -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_send_batch_to_invalid_partitions(connection_str, uamqp_transport): partitions = ["XYZ", "-1", "1000", "-"] @@ -110,7 +103,6 @@ def test_send_batch_to_invalid_partitions(connection_str, uamqp_transport): client.close() -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_send_batch_too_large_message(connection_str, uamqp_transport): if sys.platform.startswith('darwin'): @@ -125,7 +117,6 @@ def test_send_batch_too_large_message(connection_str, uamqp_transport): client.close() -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_send_batch_null_body(connection_str, uamqp_transport): client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) @@ -139,7 +130,6 @@ def test_send_batch_null_body(connection_str, uamqp_transport): client.close() -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_create_batch_with_invalid_hostname_sync(invalid_hostname, uamqp_transport): if sys.platform.startswith('darwin'): @@ -151,7 +141,6 @@ def test_create_batch_with_invalid_hostname_sync(invalid_hostname, uamqp_transpo client.create_batch(max_size_in_bytes=300) -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_create_batch_with_too_large_size_sync(connection_str, uamqp_transport): client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py index 678fccabc106..6a4cb8b6eccf 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py @@ -9,12 +9,8 @@ from azure.eventhub import EventHubSharedKeyCredential from azure.eventhub import EventHubConsumerClient from azure.eventhub.exceptions import AuthenticationError, ConnectError, EventHubError -from ..._test_case import get_decorator -uamqp_transport_vals = get_decorator() - -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_get_properties(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', @@ -25,7 +21,6 @@ def test_get_properties(live_eventhub, uamqp_transport): properties = client.get_eventhub_properties() assert properties['eventhub_name'] == live_eventhub['event_hub'] and properties['partition_ids'] == ['0', '1'] -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_get_properties_with_auth_error_sync(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', @@ -43,7 +38,6 @@ def test_get_properties_with_auth_error_sync(live_eventhub, uamqp_transport): with pytest.raises(AuthenticationError) as e: client.get_eventhub_properties() -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_get_properties_with_connect_error(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], "invalid", '$default', @@ -62,7 +56,6 @@ def test_get_properties_with_connect_error(live_eventhub, uamqp_transport): with pytest.raises(EventHubError) as e: # This can be either ConnectError or ConnectionLostError client.get_eventhub_properties() -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_get_partition_ids(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', @@ -74,7 +67,6 @@ def test_get_partition_ids(live_eventhub, uamqp_transport): assert partition_ids == ['0', '1'] -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_get_partition_properties(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py index cec3cb429904..22133f7983e3 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py @@ -13,13 +13,8 @@ from azure.eventhub import EventData, TransportType, EventHubConsumerClient from azure.eventhub.exceptions import EventHubError -from ..._test_case import get_decorator -uamqp_transport_vals = get_decorator() - -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_receive_end_of_stream(connstr_senders, uamqp_transport): def on_event(partition_context, event): @@ -51,7 +46,6 @@ def on_event(partition_context, event): thread.join() -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.parametrize("position, inclusive, expected_result", [("offset", False, "Exclusive"), ("offset", True, "Inclusive"), @@ -108,37 +102,7 @@ def on_event(partition_context, event): thread.join() -# TODO: after fixing message property mutability, test -#@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) -#@pytest.mark.liveTest -#def test_receive_modify_message_resend_sync(uamqp_transport, connstr_senders): -# received_modified = [False] -# def on_event(partition_context, event): -# message = event.message -# if message.properties.message_id == b'a1': -# message.properties.message_id = 'a2' -# senders[0].send(event) -# elif message.properties.message_id == b'a2': -# received_modified = [True] -# -# connection_str, senders = connstr_senders -# event = EventData("A", message_id='a1') -# senders[0].send(event) -# client = EventHubConsumerClient.from_connection_string( -# connection_str, consumer_group='$default', uamqp_transport=uamqp_transport -# ) -# with client: -# thread = threading.Thread(target=client.receive, args=(on_event,), -# kwargs={"partition_id": "0", "starting_position": "-1"}) -# thread.daemon = True -# thread.start() -# time.sleep(10) -# assert received_modified[0] -# thread.join() - - -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) + @pytest.mark.liveTest def test_receive_owner_level(connstr_senders, uamqp_transport): def on_event(partition_context, event): @@ -171,8 +135,6 @@ def on_error(partition_context, error): assert isinstance(on_error.error, EventHubError) -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_receive_over_websocket_sync(connstr_senders, uamqp_transport): app_prop = {"raw_prop": "raw_value"} diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py index 2bed5ff4a332..c07489acface 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py @@ -18,14 +18,10 @@ import uamqp from uamqp import compat from azure.eventhub._transport._uamqp_transport import UamqpTransport -from ..._test_case import get_decorator -uamqp_transport_vals = get_decorator() -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest -def test_send_with_long_interval_sync(live_eventhub, sleep, uamqp_transport): +def test_send_with_long_interval_sync(live_eventhub, sleep, uamqp_transport, timeout_factor): test_partition = "0" sender = EventHubProducerClient(live_eventhub['hostname'], live_eventhub['event_hub'], EventHubSharedKeyCredential(live_eventhub['key_name'], @@ -38,7 +34,10 @@ def test_send_with_long_interval_sync(live_eventhub, sleep, uamqp_transport): if sleep: time.sleep(250) else: - sender._producers[test_partition]._handler._connection.close() + if uamqp_transport: + sender._producers[test_partition]._handler._connection._conn.destroy() + else: + pass batch = sender.create_batch(partition_id=test_partition) batch.add(EventData(b"A single event")) sender.send_batch(batch) @@ -46,38 +45,35 @@ def test_send_with_long_interval_sync(live_eventhub, sleep, uamqp_transport): received = [] uri = "sb://{}/{}".format(live_eventhub['hostname'], live_eventhub['event_hub']) - sas_auth = uamqp.authentication.SASTokenAuth( - uri, uri, live_eventhub['key_name'], live_eventhub['access_key'] - ) + if uamqp_transport: + sas_auth = uamqp.authentication.SASTokenAuth.from_shared_access_key( + uri, live_eventhub['key_name'], live_eventhub['access_key']) source = "amqps://{}/{}/ConsumerGroups/{}/Partitions/{}".format( live_eventhub['hostname'], live_eventhub['event_hub'], live_eventhub['consumer_group'], test_partition) - receiver = uamqp.ReceiveClient(live_eventhub['hostname'], source, auth=sas_auth, debug=False, link_credit=500) + if uamqp_transport: + receiver = uamqp.ReceiveClient(source, auth=sas_auth, debug=False, timeout=5000, prefetch=500) try: receiver.open() # receive_message_batch() returns immediately once it receives any messages before the max_batch_size # and timeout reach. Could be 1, 2, or any number between 1 and max_batch_size. # So call it twice to ensure the two events are received. - received.extend([EventData._from_message(x) for x in receiver.receive_message_batch(max_batch_size=1, timeout=5)]) - received.extend([EventData._from_message(x) for x in receiver.receive_message_batch(max_batch_size=1, timeout=5)]) + received.extend([EventData._from_message(x) for x in receiver.receive_message_batch(max_batch_size=1, timeout=5 * timeout_factor)]) + received.extend([EventData._from_message(x) for x in receiver.receive_message_batch(max_batch_size=1, timeout=5 * timeout_factor)]) finally: receiver.close() assert len(received) == 2 assert list(received[0].body)[0] == b"A single event" -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest -def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, uamqp_transport): +def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers amqp_transport = UamqpTransport - retry_total = 3 - # no retry, should just raise error client = EventHubProducerClient.from_connection_string( - conn_str=connection_str, idle_timeout=10, retry_total=retry_total, uamqp_transport=uamqp_transport + conn_str=connection_str, idle_timeout=10, uamqp_transport=uamqp_transport ) with client: ed = EventData('data') @@ -115,8 +111,7 @@ def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, uamq retry = 0 while retry < 3: try: - timeout = 10000 if uamqp_transport else 10 - messages = receivers[0].receive_message_batch(max_batch_size=10, timeout=timeout) + messages = receivers[0].receive_message_batch(max_batch_size=10, timeout=10 * timeout_factor) if messages: received_ed1 = EventData._from_message(messages[0]) assert received_ed1.body_as_str() == 'data' @@ -125,8 +120,6 @@ def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, uamq retry += 1 -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_receive_connection_idle_timeout_and_reconnect_sync(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py index 264f63232cce..fdc3b640791b 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py @@ -26,12 +26,8 @@ from azure.eventhub._transport._uamqp_transport import UamqpTransport except (ImportError, ModuleNotFoundError): UamqpTransport = None -from ..._test_case import get_decorator -uamqp_transport_vals = get_decorator() -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_send_with_partition_key(connstr_receivers, live_eventhub, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers @@ -106,8 +102,6 @@ def test_send_with_partition_key(connstr_receivers, live_eventhub, uamqp_transpo assert len(found_partition_keys) == 6 -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_send_and_receive_large_body_size(connstr_receivers, uamqp_transport, timeout_factor): if sys.platform.startswith('darwin'): @@ -147,8 +141,6 @@ def test_send_and_receive_large_body_size(connstr_receivers, uamqp_transport, ti assert len(list(received[1].body)[0]) == payload -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_send_amqp_annotated_message(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers @@ -271,9 +263,8 @@ def on_event(partition_context, event): assert received_count["normal_msg"] == 3 -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.parametrize("payload", - [(b""), (b"A single event")]) + [b"", b"A single event"]) @pytest.mark.liveTest def test_send_and_receive_small_body(connstr_receivers, payload, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers @@ -293,8 +284,6 @@ def test_send_and_receive_small_body(connstr_receivers, payload, uamqp_transport assert list(received[1].body)[0] == payload -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_send_partition(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers @@ -337,8 +326,6 @@ def test_send_partition(connstr_receivers, uamqp_transport, timeout_factor): assert len(partition_0) + len(partition_1) == 4 -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_send_non_ascii(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers @@ -364,8 +351,6 @@ def test_send_non_ascii(connstr_receivers, uamqp_transport, timeout_factor): assert partition_0[3].body_as_json() == {"foo": u"漢字"} -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_send_multiple_partitions_with_app_prop(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers @@ -399,8 +384,6 @@ def test_send_multiple_partitions_with_app_prop(connstr_receivers, uamqp_transpo assert partition_1[1].properties[b"raw_prop"] == b"raw_value" -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_send_over_websocket_sync(connstr_receivers, uamqp_transport, timeout_factor): timeout = 10 * timeout_factor @@ -421,8 +404,6 @@ def test_send_over_websocket_sync(connstr_receivers, uamqp_transport, timeout_fa assert len(received) == 2 -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers @@ -450,8 +431,6 @@ def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers, uamq assert EventData._from_message(received[0]).properties[b"raw_prop"] == b"raw_value" -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_send_list(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers @@ -468,8 +447,6 @@ def test_send_list(connstr_receivers, uamqp_transport, timeout_factor): assert received[0].body_as_str() == payload -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_send_list_partition(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers @@ -483,7 +460,7 @@ def test_send_list_partition(connstr_receivers, uamqp_transport, timeout_factor) assert received.body_as_str() == payload -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) + @pytest.mark.parametrize("to_send, exception_type", [([EventData("A"*1024)]*1100, ValueError), ("any str", AttributeError)]) @@ -495,7 +472,7 @@ def test_send_list_wrong_data(connection_str, to_send, exception_type, uamqp_tra client.send_batch(to_send) -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) + @pytest.mark.parametrize("partition_id, partition_key", [("0", None), (None, "pk")]) def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key, uamqp_transport): # Use invalid_hostname because this is not a live test. @@ -507,7 +484,7 @@ def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key, uamqp_ client.send_batch(batch, partition_id=partition_id, partition_key=partition_key) -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) + def test_send_with_callback(connstr_receivers, uamqp_transport): def on_error(events, pid, err): @@ -555,8 +532,6 @@ def on_success(events, pid): assert not on_error.err # TODO: add more checks after LegacyMessage has been added -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) @pytest.mark.liveTest def test_send_message_modify_backcompat(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/__init__.py b/sdk/eventhub/azure-eventhub/tests/unittest/__init__.py deleted file mode 100644 index 34913fb394d7..000000000000 --- a/sdk/eventhub/azure-eventhub/tests/unittest/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py index a5dcf81ef2cc..cf7982ba4992 100644 --- a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py @@ -17,9 +17,6 @@ from azure.eventhub.amqp import AmqpAnnotatedMessage, AmqpMessageHeader, AmqpMessageProperties from azure.eventhub import _common from azure.eventhub._utils import transform_outbound_single_message -from .._test_case import get_decorator - -uamqp_transport_vals = get_decorator() pytestmark = pytest.mark.skipif(platform.python_implementation() == "PyPy", reason="This is ignored for PyPy") @@ -71,8 +68,6 @@ def test_app_properties(): assert event_data.properties["a"] == "b" -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) def test_sys_properties(uamqp_transport): if uamqp_transport: properties = uamqp.message.MessageProperties() @@ -111,8 +106,6 @@ def test_sys_properties(uamqp_transport): assert ed.system_properties[_common.PROP_REPLY_TO_GROUP_ID] == properties.reply_to_group_id -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) def test_event_data_batch(uamqp_transport): if uamqp_transport: amqp_transport = UamqpTransport() @@ -133,7 +126,7 @@ def test_event_data_batch(uamqp_transport): batch.add(EventData("A")) -@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) + def test_event_data_from_message(uamqp_transport): if uamqp_transport: amqp_transport = UamqpTransport() @@ -162,8 +155,6 @@ def test_amqp_message_str_repr(): assert 'AmqpAnnotatedMessage(body=A, body_type=data' in repr(message) -@pytest.mark.parametrize("uamqp_transport", - uamqp_transport_vals) def test_amqp_message_from_message(uamqp_transport): if uamqp_transport: header = uamqp.message.MessageHeader() From 28a14a05972eb645b9e57d46dc7dfb0c58024b1a Mon Sep 17 00:00:00 2001 From: swathipil Date: Tue, 2 Aug 2022 16:58:08 -0700 Subject: [PATCH 10/20] lint + mypy --- .../azure/eventhub/_client_base.py | 8 +- .../azure/eventhub/_consumer.py | 6 +- .../azure/eventhub/_producer_client.py | 2 - .../eventhub/_transport/_uamqp_transport.py | 4 +- .../azure-eventhub/azure/eventhub/_utils.py | 19 +- .../azure/eventhub/aio/_client_base_async.py | 16 +- .../eventhub/aio/_connection_manager_async.py | 1 + .../azure/eventhub/aio/_consumer_async.py | 4 +- .../eventhub/aio/_consumer_client_async.py | 3 +- .../azure/eventhub/aio/_producer_async.py | 8 +- .../eventhub/aio/_producer_client_async.py | 3 +- .../eventhub/aio/_transport/_base_async.py | 8 +- .../aio/_transport/_uamqp_transport_async.py | 603 +++++++++--------- .../azure/eventhub/amqp/_amqp_message.py | 5 +- .../azure/eventhub/amqp/_amqp_utils.py | 6 +- .../azure/eventhub/exceptions.py | 5 - 16 files changed, 321 insertions(+), 380 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index f72611edca5a..b1ec0651cf51 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -13,6 +13,7 @@ from datetime import timedelta from urllib.parse import urlparse, quote_plus import six +from uamqp import utils from azure.core.credentials import ( AccessToken, @@ -23,7 +24,6 @@ from azure.core.pipeline.policies import RetryMode -from uamqp import utils from ._transport._uamqp_transport import UamqpTransport from .exceptions import ClientClosedError from ._configuration import Configuration @@ -47,7 +47,7 @@ CredentialTypes = Union[ AzureSasCredential, AzureNamedKeyCredential, - EventHubSharedKeyCredential, + "EventHubSharedKeyCredential", TokenCredential, ] @@ -439,9 +439,9 @@ def _management_request( f"Management request error. Status code: {status_code}, Description: {description!r}" ) except Exception as exception: # pylint: disable=broad-except - last_exception = self._amqp_transport._handle_exception( + last_exception = self._amqp_transport._handle_exception( # pylint: disable=protected-access exception, self - ) # pylint: disable=protected-access + ) self._backoff( retried_times=retried_times, last_exception=last_exception ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index ad647a011a1a..98d3fe21c86b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -3,13 +3,12 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- from __future__ import unicode_literals, annotations -from multiprocessing import Event import time import uuid import logging from collections import deque -from typing import TYPE_CHECKING, Callable, Dict, Optional, Any, Deque, Union +from typing import TYPE_CHECKING, Callable, Dict, Optional, Any, Deque from ._common import EventData from ._client_base import ConsumerProducerMixin @@ -21,7 +20,6 @@ ) if TYPE_CHECKING: - from typing import Deque from uamqp import ReceiveClient as uamqp_ReceiveClient, Message as uamqp_Message, types as uamqp_types from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth @@ -97,7 +95,7 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs: Any) self._auto_reconnect = auto_reconnect self._retry_policy = self._amqp_transport.create_retry_policy(self._client._config) self._reconnect_backoff = 1 - link_properties: Dict[uamqp_types.AMQPTypes, uamqp_types.AMQPType] = {} + link_properties: Dict[uamqp_types.AMQPType, uamqp_types.AMQPType] = {} self._error = None self._timeout = 0 self._idle_timeout = (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) if idle_timeout else None diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index 1fecbd80f978..38536c9483e9 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -23,8 +23,6 @@ from ._producer import EventHubProducer from ._constants import ALL_PARTITIONS, MAX_MESSAGE_LENGTH_BYTES from ._common import EventDataBatch, EventData -from ._constants import ALL_PARTITIONS -from ._producer import EventHubProducer from ._buffered_producer import BufferedProducerDispatcher from ._utils import set_event_partition_key from .amqp import AmqpAnnotatedMessage diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index d928620675c0..a3ef0d18f52a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -40,8 +40,6 @@ from ..exceptions import ( ConnectError, - EventDataError, - EventDataSendError, OperationTimeoutError, EventHubError, AuthenticationError, @@ -322,7 +320,7 @@ def create_source(source, offset, selector): return source @staticmethod - def create_receive_client(*, config, **kwargs): + def create_receive_client(*, config, **kwargs): # pylint: disable=unused-argument """ Creates and returns the receive client. :param ~azure.eventhub._configuration.Configuration config: The configuration. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index a3a0e25df56a..de4502be188a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -24,6 +24,7 @@ ) import six +from uamqp import types as uamqp_types from azure.core.settings import settings from azure.core.tracing import SpanKind, Link @@ -41,15 +42,12 @@ PROP_PARTITION_KEY ) -# TODO: remove after fixing up async -from uamqp import types -PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) +PROP_PARTITION_KEY_AMQP_SYMBOL = uamqp_types.AMQPSymbol(PROP_PARTITION_KEY) if TYPE_CHECKING: # pylint: disable=ungrouped-imports from ._transport._base import AmqpTransport - from uamqp import types as uamqp_types from azure.core.tracing import AbstractSpan from azure.core.credentials import AzureSasCredential from ._common import EventData @@ -144,7 +142,7 @@ def set_event_partition_key(event, partition_key): annotations = raw_message.annotations if annotations is None: - annotations = dict() + annotations = {} annotations[ PROP_PARTITION_KEY_AMQP_SYMBOL ] = partition_key # pylint:disable=protected-access @@ -154,17 +152,6 @@ def set_event_partition_key(event, partition_key): raw_message.header.durable = True -@contextmanager -def send_context_manager(): - span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] - - if span_impl_type is not None: - with span_impl_type(name="Azure.EventHubs.send", kind=SpanKind.CLIENT) as child: - yield child - else: - yield None - - def trace_message(event, parent_span=None): # type: (EventData, Optional[AbstractSpan]) -> None """Add tracing information to this event. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py index b82ae698ed6b..1e041fde6b9e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py @@ -14,8 +14,6 @@ from uamqp import ( authentication, constants, - errors, - compat, Message, AMQPClientAsync, ) @@ -32,7 +30,7 @@ _get_backoff_time, ) from .._utils import utc_from_timestamp, parse_sas_credential -from ..exceptions import ClientClosedError, ConnectError +from ..exceptions import ClientClosedError from .._constants import ( JWT_TOKEN_SCOPE, MGMT_OPERATION, @@ -260,14 +258,14 @@ async def _create_auth_async(self) -> authentication.JWTTokenAsync: except AttributeError: token_type = b"jwt" if token_type == b"servicebus.windows.net:sastoken": - return await self._amqp_transport.create_token_auth( + return await self._amqp_transport.create_token_auth_async( self._auth_uri, functools.partial(self._credential.get_token, self._auth_uri), token_type=token_type, config=self._config, update_token=True, ) - return await self._amqp_transport.create_token_auth( + return await self._amqp_transport.create_token_auth_async( self._auth_uri, functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE), token_type=token_type, @@ -323,8 +321,8 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> await asyncio.sleep(0.05) mgmt_msg.application_properties[ "security_token" - ] = await self._amqp_transport.get_updated_token(mgmt_auth) - response = await self._amqp_transport.mgmt_client_request( + ] = await self._amqp_transport.get_updated_token_async(mgmt_auth) + response = await self._amqp_transport.mgmt_client_request_async( mgmt_client, mgmt_msg, operation=READ_OPERATION, @@ -357,7 +355,7 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> except asyncio.CancelledError: # pylint: disable=try-except-raise raise except Exception as exception: # pylint:disable=broad-except - last_exception = await self._amqp_transport._handle_exception(exception, self) # pylint: disable=protected-access + last_exception = await self._amqp_transport._handle_exception_async(exception, self) # pylint: disable=protected-access await self._backoff_async( retried_times=retried_times, last_exception=last_exception ) @@ -481,7 +479,7 @@ async def _handle_exception(self, exception: Exception) -> Exception: self._amqp_transport.AUTH_EXCEPTION, "Authorization timeout." ) - return await self._amqp_transport._handle_exception( # pylint: disable=protected-access + return await self._amqp_transport._handle_exception_async( # pylint: disable=protected-access exception, self ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py index 32e544344989..ab91e4939967 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------------------------- from typing import TYPE_CHECKING +from asyncio import Lock from uamqp import c_uamqp from uamqp.async_ops import ConnectionAsync diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py index c6f375184b09..91c0d1a29f9d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py @@ -3,12 +3,10 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- from __future__ import annotations -import time -import asyncio import uuid import logging from collections import deque -from typing import TYPE_CHECKING, Callable, Awaitable, cast, Dict, Optional, Union, List +from typing import TYPE_CHECKING, Callable, Awaitable, Dict, Optional, Union, List from ._client_base_async import ConsumerProducerMixin from ._async_utils import get_dict_with_loop_if_needed diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py index 0cf4634a655b..73ef31f79ac4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py @@ -28,7 +28,6 @@ if TYPE_CHECKING: from ._client_base_async import CredentialTypes - from uamqp.constants import TransportType from ._eventprocessor.partition_context import PartitionContext from ._eventprocessor.checkpoint_store import CheckpointStore from .._common import EventData @@ -233,7 +232,7 @@ def from_connection_string( auth_timeout: float = 60, user_agent: Optional[str] = None, retry_total: int = 3, - transport_type: Optional["TransportType"] = TransportType.Amqp, + transport_type: TransportType = TransportType.Amqp, checkpoint_store: Optional["CheckpointStore"] = None, load_balancing_interval: float = 10, **kwargs: Any diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py index 7c4cd4be1d62..6fb84299bfc2 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py @@ -7,12 +7,10 @@ import asyncio import logging from typing import Iterable, Union, Optional, Any, AnyStr, List, TYPE_CHECKING -import time from azure.core.tracing import AbstractSpan from .._common import EventData, EventDataBatch -from ..exceptions import OperationTimeoutError from .._producer import _set_partition_key, _set_trace_message from .._utils import ( create_properties, @@ -134,7 +132,7 @@ async def _send_event_data( last_exception: Optional[Exception] = None, ) -> None: if self._unsent_events: - await self._amqp_transport.send_messages( + await self._amqp_transport.send_messages_async( self, timeout_time, last_exception, _LOGGER ) @@ -200,7 +198,9 @@ def _wrap_eventdata( event_data, partition_key, self._amqp_transport ) event_data = _set_trace_message(event_data, span) - wrapper_event_data = EventDataBatch._from_batch(event_data, self._amqp_transport, partition_key) # type: ignore # pylint: disable=protected-access + wrapper_event_data = EventDataBatch._from_batch( # type: ignore # pylint: disable=protected-access + event_data, self._amqp_transport, partition_key + ) return wrapper_event_data async def send( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py index 15e7c345f8be..bfe03aa365aa 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py @@ -21,7 +21,6 @@ if TYPE_CHECKING: from ._client_base_async import CredentialTypes - from uamqp.constants import TransportType # pylint: disable=ungrouped-imports SendEventTypes = List[Union[EventData, AmqpAnnotatedMessage]] @@ -406,7 +405,7 @@ def from_connection_string( auth_timeout: float = 60, user_agent: Optional[str] = None, retry_total: int = 3, - transport_type: Optional["TransportType"] = TransportType.Amqp, + transport_type: TransportType = TransportType.Amqp, **kwargs: Any ) -> "EventHubProducerClient": """Create an EventHubProducerClient from a connection string. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py index 6c1d7837e3f6..b495f2028dc7 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py @@ -93,7 +93,7 @@ def create_send_client(*, config, **kwargs): @staticmethod @abstractmethod - async def send_messages(producer, timeout_time, last_exception, logger): + async def send_messages_async(producer, timeout_time, last_exception, logger): """ Handles sending of event data messages. :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. @@ -159,7 +159,7 @@ async def receive_messages(handler, batch, max_batch_size, max_wait_time): @staticmethod @abstractmethod - async def create_token_auth(auth_uri, get_token, token_type, config, **kwargs): + async def create_token_auth_async(auth_uri, get_token, token_type, config, **kwargs): """ Creates the JWTTokenAuth. :param str auth_uri: The auth uri to pass to JWTTokenAuth. @@ -184,7 +184,7 @@ def create_mgmt_client(address, mgmt_auth, config): @staticmethod @abstractmethod - async def get_updated_token(mgmt_auth): + async def get_updated_token_async(mgmt_auth): """ Return updated auth token. :param mgmt_auth: Auth. @@ -192,7 +192,7 @@ async def get_updated_token(mgmt_auth): @staticmethod @abstractmethod - async def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): + async def mgmt_client_request_async(mgmt_client, mgmt_msg, **kwargs): """ Send mgmt request. :param AMQP Client mgmt_client: Client to send request with. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py index 7a57a44ef361..61fed458e44a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py @@ -9,45 +9,22 @@ import logging from typing import Union, cast, TYPE_CHECKING, List -try: - from uamqp import ( - BatchMessage, - constants, - MessageBodyType, - Message, - types, - SendClientAsync, - ReceiveClientAsync, - Source, - utils, - authentication, - AMQPClientAsync, - compat, - errors, - ) - from uamqp.message import ( - MessageHeader, - MessageProperties, - ) - uamqp_installed = True -except ImportError: - uamqp_installed = False - -from ._base_async import AmqpTransportAsync -from ...amqp._constants import AmqpMessageBodyType -from ..._constants import ( - NO_RETRY_ERRORS, - PROP_PARTITION_KEY, +from uamqp import ( + constants, + types, + SendClientAsync, + ReceiveClientAsync, + utils, + authentication, + AMQPClientAsync, + errors, ) +from ._base_async import AmqpTransportAsync +from ..._transport._uamqp_transport import UamqpTransport from ...exceptions import ( - ConnectError, - EventDataError, - EventDataSendError, OperationTimeoutError, EventHubError, - AuthenticationError, - ConnectionLostError, EventDataError, EventDataSendError, ) @@ -58,307 +35,303 @@ _LOGGER = logging.getLogger(__name__) -if uamqp_installed: - - from ..._transport._uamqp_transport import UamqpTransport +class UamqpTransportAsync(UamqpTransport, AmqpTransportAsync): + """ + Class which defines uamqp-based methods used by the producer and consumer. + """ - class UamqpTransportAsync(UamqpTransport, AmqpTransportAsync): - """ - Class which defines uamqp-based methods used by the producer and consumer. + @staticmethod + def create_send_client(*, config, **kwargs): # pylint:disable=unused-argument """ + Creates and returns the uamqp SendClient. + :param ~azure.eventhub._configuration.Configuration config: The configuration. - @staticmethod - def create_send_client(*, config, **kwargs): # pylint:disable=unused-argument - """ - Creates and returns the uamqp SendClient. - :param ~azure.eventhub._configuration.Configuration config: The configuration. - - :keyword str target: Required. The target. - :keyword JWTTokenAuth auth: Required. - :keyword int idle_timeout: Required. - :keyword network_trace: Required. - :keyword retry_policy: Required. - :keyword keep_alive_interval: Required. - :keyword str client_name: Required. - :keyword dict link_properties: Required. - :keyword properties: Required. - """ - target = kwargs.pop("target") - retry_policy = kwargs.pop("retry_policy") - network_trace = kwargs.pop("network_trace") + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + target = kwargs.pop("target") + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") - return SendClientAsync( - target, - debug=network_trace, # pylint:disable=protected-access - error_policy=retry_policy, - **kwargs - ) + return SendClientAsync( + target, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + **kwargs + ) - @staticmethod - async def send_messages(producer, timeout_time, last_exception, logger): - """ - Handles sending of event data messages. - :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. - :param int timeout_time: Timeout time. - :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. - :param logger: Logger. - """ - # pylint: disable=protected-access - await producer._open() - producer._unsent_events[0].on_send_complete = producer._on_outcome - UamqpTransportAsync._set_msg_timeout(producer, timeout_time, last_exception, logger) - producer._handler.queue_message(*producer._unsent_events) # type: ignore - await producer._handler.wait_async() # type: ignore - producer._unsent_events = producer._handler.pending_messages # type: ignore - if producer._outcome != constants.MessageSendResult.Ok: - if producer._outcome == constants.MessageSendResult.Timeout: - producer._condition = OperationTimeoutError("Send operation timed out") - if producer._condition: - raise producer._condition + @staticmethod + async def send_messages_async(producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + # pylint: disable=protected-access + await producer._open() + producer._unsent_events[0].on_send_complete = producer._on_outcome + UamqpTransportAsync._set_msg_timeout(producer, timeout_time, last_exception, logger) + producer._handler.queue_message(*producer._unsent_events) # type: ignore + await producer._handler.wait_async() # type: ignore + producer._unsent_events = producer._handler.pending_messages # type: ignore + if producer._outcome != constants.MessageSendResult.Ok: + if producer._outcome == constants.MessageSendResult.Timeout: + producer._condition = OperationTimeoutError("Send operation timed out") + if producer._condition: + raise producer._condition - @staticmethod - def create_receive_client(*, config, **kwargs): - """ - Creates and returns the receive client. - :param ~azure.eventhub._configuration.Configuration config: The configuration. + @staticmethod + def create_receive_client(*, config, **kwargs): # pylint:disable=unused-argument + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. - :keyword str source: Required. The source. - :keyword str offset: Required. - :keyword str offset_inclusive: Required. - :keyword JWTTokenAuth auth: Required. - :keyword int idle_timeout: Required. - :keyword network_trace: Required. - :keyword retry_policy: Required. - :keyword str client_name: Required. - :keyword dict link_properties: Required. - :keyword properties: Required. - :keyword link_credit: Required. The prefetch. - :keyword keep_alive_interval: Required. - :keyword desired_capabilities: Required. - :keyword streaming_receive: Required. - :keyword message_received_callback: Required. - :keyword timeout: Required. - """ + :keyword str source: Required. The source. + :keyword str offset: Required. + :keyword str offset_inclusive: Required. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ - source = kwargs.pop("source") - symbol_array = kwargs.pop("desired_capabilities") - desired_capabilities = None - if symbol_array: - symbol_array = [types.AMQPSymbol(symbol) for symbol in symbol_array] - desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) - retry_policy = kwargs.pop("retry_policy") - network_trace = kwargs.pop("network_trace") - link_credit = kwargs.pop("link_credit") - streaming_receive = kwargs.pop("streaming_receive") - message_received_callback = kwargs.pop("message_received_callback") + source = kwargs.pop("source") + symbol_array = kwargs.pop("desired_capabilities") + desired_capabilities = None + if symbol_array: + symbol_array = [types.AMQPSymbol(symbol) for symbol in symbol_array] + desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + link_credit = kwargs.pop("link_credit") + streaming_receive = kwargs.pop("streaming_receive") + message_received_callback = kwargs.pop("message_received_callback") - client = ReceiveClientAsync( - source, - debug=network_trace, # pylint:disable=protected-access - error_policy=retry_policy, - desired_capabilities=desired_capabilities, - prefetch=link_credit, - receive_settle_mode=constants.ReceiverSettleMode.ReceiveAndDelete, - auto_complete=False, - **kwargs - ) - # pylint:disable=protected-access - client._streaming_receive = streaming_receive - client._message_received_callback = (message_received_callback) - return client + client = ReceiveClientAsync( + source, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + desired_capabilities=desired_capabilities, + prefetch=link_credit, + receive_settle_mode=constants.ReceiverSettleMode.ReceiveAndDelete, + auto_complete=False, + **kwargs + ) + # pylint:disable=protected-access + client._streaming_receive = streaming_receive + client._message_received_callback = (message_received_callback) + return client - @staticmethod - async def receive_messages(handler, batch, max_batch_size, max_wait_time): - """ - Receives messages, creates events, and returns them by calling the on received callback. - :param ReceiveClient handler: The receive client. - :param bool batch: If receive batch or single event. - :param int max_batch_size: Max batch size. - :param int or None max_wait_time: Max wait time. - """ - # pylint:disable=protected-access - max_retries = ( - handler._client._config.max_retries # pylint:disable=protected-access - ) - has_not_fetched_once = True # ensure one trip when max_wait_time is very small - deadline = time.time() + (max_wait_time or 0) # max_wait_time can be None - while len(handler._message_buffer) < max_batch_size and ( - time.time() < deadline or has_not_fetched_once - ): - retried_times = 0 - has_not_fetched_once = False - while retried_times <= max_retries: - try: - await handler._open() - await cast( - ReceiveClientAsync, handler._handler - ).do_work_async() # uamqp sleeps 0.05 if none received - break - except asyncio.CancelledError: # pylint: disable=try-except-raise - raise - except Exception as exception: # pylint: disable=broad-except - if ( - isinstance(exception, UamqpTransportAsync.AMQP_LINK_ERROR) - and exception.condition == UamqpTransportAsync.LINK_STOLEN_CONDITION # pylint: disable=no-member - ): - raise await handler._handle_exception(exception) - if not handler.running: # exit by close - return - if handler._last_received_event: - handler._offset = handler._last_received_event.offset - last_exception = await handler._handle_exception(exception) - retried_times += 1 - if retried_times > max_retries: - _LOGGER.info( - "%r operation has exhausted retry. Last exception: %r.", - handler._name, - last_exception, - ) - raise last_exception + @staticmethod + async def receive_messages(handler, batch, max_batch_size, max_wait_time): + """ + Receives messages, creates events, and returns them by calling the on received callback. + :param ReceiveClient handler: The receive client. + :param bool batch: If receive batch or single event. + :param int max_batch_size: Max batch size. + :param int or None max_wait_time: Max wait time. + """ + # pylint:disable=protected-access + max_retries = ( + handler._client._config.max_retries # pylint:disable=protected-access + ) + has_not_fetched_once = True # ensure one trip when max_wait_time is very small + deadline = time.time() + (max_wait_time or 0) # max_wait_time can be None + while len(handler._message_buffer) < max_batch_size and ( + time.time() < deadline or has_not_fetched_once + ): + retried_times = 0 + has_not_fetched_once = False + while retried_times <= max_retries: + try: + await handler._open() + await cast( + ReceiveClientAsync, handler._handler + ).do_work_async() # uamqp sleeps 0.05 if none received + break + except asyncio.CancelledError: # pylint: disable=try-except-raise + raise + except Exception as exception: # pylint: disable=broad-except + if ( + isinstance(exception, UamqpTransportAsync.AMQP_LINK_ERROR) + and exception.condition == UamqpTransportAsync.LINK_STOLEN_CONDITION # pylint: disable=no-member + ): + raise await handler._handle_exception(exception) + if not handler.running: # exit by close + return + if handler._last_received_event: + handler._offset = handler._last_received_event.offset + last_exception = await handler._handle_exception(exception) + retried_times += 1 + if retried_times > max_retries: + _LOGGER.info( + "%r operation has exhausted retry. Last exception: %r.", + handler._name, + last_exception, + ) + raise last_exception - if handler._message_buffer: - while handler._message_buffer: - if batch: - events_for_callback: List[EventData] = [] - for _ in range(min(max_batch_size, len(handler._message_buffer))): - events_for_callback.append(handler._next_message_in_buffer()) - await handler._on_event_received(events_for_callback) - else: - await handler._on_event_received(handler._next_message_in_buffer()) - elif max_wait_time: + if handler._message_buffer: + while handler._message_buffer: if batch: - await handler._on_event_received([]) + events_for_callback: List[EventData] = [] + for _ in range(min(max_batch_size, len(handler._message_buffer))): + events_for_callback.append(handler._next_message_in_buffer()) + await handler._on_event_received(events_for_callback) else: - await handler._on_event_received(None) + await handler._on_event_received(handler._next_message_in_buffer()) + elif max_wait_time: + if batch: + await handler._on_event_received([]) + else: + await handler._on_event_received(None) - @staticmethod - async def create_token_auth(auth_uri, get_token, token_type, config, **kwargs): - """ - Creates the JWTTokenAuth. - :param str auth_uri: The auth uri to pass to JWTTokenAuth. - :param get_token: The callback function used for getting and refreshing - tokens. It should return a valid jwt token each time it is called. - :param bytes token_type: Token type. - :param ~azure.eventhub._configuration.Configuration config: EH config. + @staticmethod + async def create_token_auth_async(auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. - :keyword bool update_token: Required. Whether to update token. If not updating token, - then pass 300 to refresh_window. - """ - update_token = kwargs.pop("update_token") - refresh_window = 300 - if update_token: - refresh_window = 0 + :keyword bool update_token: Required. Whether to update token. If not updating token, + then pass 300 to refresh_window. + """ + update_token = kwargs.pop("update_token") + refresh_window = 300 + if update_token: + refresh_window = 0 - token_auth = authentication.JWTTokenAsync( - auth_uri, - auth_uri, - get_token, - token_type=token_type, - timeout=config.auth_timeout, - http_proxy=config.http_proxy, - transport_type=config.transport_type, - custom_endpoint_hostname=config.custom_endpoint_hostname, - port=config.connection_port, - verify=config.connection_verify, - refresh_window=refresh_window - ) - if update_token: - await token_auth.update_token() - return token_auth + token_auth = authentication.JWTTokenAsync( + auth_uri, + auth_uri, + get_token, + token_type=token_type, + timeout=config.auth_timeout, + http_proxy=config.http_proxy, + transport_type=config.transport_type, + custom_endpoint_hostname=config.custom_endpoint_hostname, + port=config.connection_port, + verify=config.connection_verify, + refresh_window=refresh_window + ) + if update_token: + await token_auth.update_token() + return token_auth - @staticmethod - def create_mgmt_client(address, mgmt_auth, config): - """ - Creates and returns the mgmt AMQP client. - :param _Address address: Required. The Address. - :param JWTTokenAuth mgmt_auth: Auth for client. - :param ~azure.eventhub._configuration.Configuration config: The configuration. - """ + @staticmethod + def create_mgmt_client(address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ - mgmt_target = f"amqps://{address.hostname}{address.path}" - return AMQPClientAsync( - mgmt_target, - auth=mgmt_auth, - debug=config.network_tracing - ) + mgmt_target = f"amqps://{address.hostname}{address.path}" + return AMQPClientAsync( + mgmt_target, + auth=mgmt_auth, + debug=config.network_tracing + ) - @staticmethod - async def get_updated_token(mgmt_auth): - """ - Return updated auth token. - :param mgmt_auth: Auth. - """ - return mgmt_auth.token + @staticmethod + async def get_updated_token_async(mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + return mgmt_auth.token - @staticmethod - async def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): - """ - Send mgmt request. - :param AMQP Client mgmt_client: Client to send request with. - :param str mgmt_msg: Message. - :keyword bytes operation: Operation. - :keyword operation_type: Op type. - :keyword status_code_field: mgmt status code. - :keyword description_fields: mgmt status desc. - """ - operation_type = kwargs.pop("operation_type") - operation = kwargs.pop("operation") - return await mgmt_client.mgmt_request_async( - mgmt_msg, - operation, - op_type=operation_type, - **kwargs - ) + @staticmethod + async def mgmt_client_request_async(mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + operation_type = kwargs.pop("operation_type") + operation = kwargs.pop("operation") + return await mgmt_client.mgmt_request_async( + mgmt_msg, + operation, + op_type=operation_type, + **kwargs + ) - @staticmethod - async def _handle_exception( # pylint:disable=too-many-branches, too-many-statements - exception: Exception, closable: Union["ClientBaseAsync", "ConsumerProducerMixin"] - ) -> Exception: - # pylint: disable=protected-access - if isinstance(exception, asyncio.CancelledError): - raise exception - error = exception + @staticmethod + async def _handle_exception_async( # pylint:disable=too-many-branches, too-many-statements + exception: Exception, closable: Union["ClientBaseAsync", "ConsumerProducerMixin"] + ) -> Exception: + # pylint: disable=protected-access + if isinstance(exception, asyncio.CancelledError): + raise exception + error = exception + try: + name = cast("ConsumerProducerMixin", closable)._name + except AttributeError: + name = cast("ClientBaseAsync", closable)._container_id + if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise + _LOGGER.info("%r stops due to keyboard interrupt", name) + await cast("ConsumerProducerMixin", closable)._close_connection_async() + raise error + elif isinstance(exception, EventHubError): + await cast("ConsumerProducerMixin", closable)._close_handler_async() + raise error + elif isinstance( + exception, + ( + errors.MessageAccepted, + errors.MessageAlreadySettled, + errors.MessageModified, + errors.MessageRejected, + errors.MessageReleased, + errors.MessageContentTooLarge, + ), + ): + _LOGGER.info("%r Event data error (%r)", name, exception) + error = EventDataError(str(exception), exception) + raise error + elif isinstance(exception, errors.MessageException): + _LOGGER.info("%r Event data send error (%r)", name, exception) + error = EventDataSendError(str(exception), exception) + raise error + else: try: - name = cast("ConsumerProducerMixin", closable)._name + if isinstance(exception, errors.AuthenticationException): + await closable._close_connection_async() + elif isinstance(exception, errors.LinkDetach): + await cast("ConsumerProducerMixin", closable)._close_handler_async() + elif isinstance(exception, errors.ConnectionClose): + await closable._close_connection_async() + elif isinstance(exception, errors.MessageHandlerError): + await cast("ConsumerProducerMixin", closable)._close_handler_async() + else: # errors.AMQPConnectionError, compat.TimeoutException, and any other errors + await closable._close_connection_async() except AttributeError: - name = cast("ClientBaseAsync", closable)._container_id - if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise - _LOGGER.info("%r stops due to keyboard interrupt", name) - await cast("ConsumerProducerMixin", closable)._close_connection_async() - raise error - elif isinstance(exception, EventHubError): - await cast("ConsumerProducerMixin", closable)._close_handler_async() - raise error - elif isinstance( - exception, - ( - errors.MessageAccepted, - errors.MessageAlreadySettled, - errors.MessageModified, - errors.MessageRejected, - errors.MessageReleased, - errors.MessageContentTooLarge, - ), - ): - _LOGGER.info("%r Event data error (%r)", name, exception) - error = EventDataError(str(exception), exception) - raise error - elif isinstance(exception, errors.MessageException): - _LOGGER.info("%r Event data send error (%r)", name, exception) - error = EventDataSendError(str(exception), exception) - raise error - else: - try: - if isinstance(exception, errors.AuthenticationException): - await closable._close_connection_async() - elif isinstance(exception, errors.LinkDetach): - await cast("ConsumerProducerMixin", closable)._close_handler_async() - elif isinstance(exception, errors.ConnectionClose): - await closable._close_connection_async() - elif isinstance(exception, errors.MessageHandlerError): - await cast("ConsumerProducerMixin", closable)._close_handler_async() - else: # errors.AMQPConnectionError, compat.TimeoutException, and any other errors - await closable._close_connection_async() - except AttributeError: - pass - return UamqpTransport._create_eventhub_exception(exception) \ No newline at end of file + pass + return UamqpTransport._create_eventhub_exception(exception) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index f0846f9b49fd..c883a55f5cc3 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -95,9 +95,9 @@ def __init__(self, **kwargs): def __str__(self) -> str: if self._body_type == AmqpMessageBodyType.DATA: return "".join(d.decode(self._encoding) for d in self._data_body) - elif self._body_type == AmqpMessageBodyType.SEQUENCE: + if self._body_type == AmqpMessageBodyType.SEQUENCE: return str(self._sequence_body) - elif self._body_type == AmqpMessageBodyType.VALUE: + if self._body_type == AmqpMessageBodyType.VALUE: return str(self._value_body) return "" @@ -174,7 +174,6 @@ def _from_amqp_message(self, message): @property def body(self) -> Any: - # type: () -> Any """The body of the Message. The format may vary depending on the body type: For ~azure.eventhub.AmqpMessageBodyType.DATA, the body could be bytes or Iterable[bytes] For ~azure.eventhub.AmqpMessageBodyType.SEQUENCE, the body could be List or Iterable[List] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py index 4bb676392f89..c620c149ea5e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py @@ -15,13 +15,11 @@ def normalized_data_body(data, **kwargs): encoding = kwargs.get("encoding", "utf-8") if isinstance(data, list): return [encode_str(item, encoding) for item in data] - else: - return [encode_str(data, encoding)] - + return [encode_str(data, encoding)] def normalized_sequence_body(sequence): # A helper method to normalize input into AMQP Sequence Body format if isinstance(sequence, list) and all([isinstance(b, list) for b in sequence]): return sequence - elif isinstance(sequence, list): + if isinstance(sequence, list): return [sequence] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py index a688e70a8861..f686251e6e95 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py @@ -3,11 +3,6 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- import six -try: - from uamqp import errors, compat -except ImportError: - errors = None - compat = None class EventHubError(Exception): """Represents an error occurred in the client. From f9e9a282adbb20e93dd59265fc52ae405eeffab3 Mon Sep 17 00:00:00 2001 From: swathipil Date: Mon, 8 Aug 2022 17:53:28 -0700 Subject: [PATCH 11/20] address Anna/Libba/Kashifs comments --- .../_buffered_producer/_buffered_producer.py | 4 +- .../_buffered_producer_dispatcher.py | 4 +- .../azure/eventhub/_client_base.py | 15 ++---- .../azure-eventhub/azure/eventhub/_common.py | 9 +--- .../azure/eventhub/_connection_manager.py | 43 +++++++-------- .../azure/eventhub/_consumer.py | 5 +- .../azure/eventhub/_producer.py | 2 +- .../azure/eventhub/_producer_client.py | 2 +- .../azure/eventhub/_transport/_base.py | 36 +++++++++++++ .../eventhub/_transport/_uamqp_transport.py | 50 ++++++++++++++++++ .../azure-eventhub/azure/eventhub/_utils.py | 52 ++++++++++++++++--- .../azure/eventhub/aio/_client_base_async.py | 4 +- .../eventhub/aio/_connection_manager_async.py | 25 ++++----- .../azure/eventhub/aio/_consumer_async.py | 5 +- .../azure/eventhub/aio/_producer_async.py | 2 +- .../eventhub/aio/_producer_client_async.py | 2 +- .../eventhub/aio/_transport/_base_async.py | 35 +++++++++++++ .../aio/_transport/_uamqp_transport_async.py | 34 ++++++++++++ .../azure/eventhub/amqp/_amqp_message.py | 18 +++---- 19 files changed, 259 insertions(+), 88 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py index 7baeaf085d3e..4eb2cc55b1b0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py @@ -32,9 +32,9 @@ def __init__( max_message_size_on_link: int, executor: ThreadPoolExecutor, *, - max_wait_time: float = 1, + amqp_transport: AmqpTransport, max_buffer_length: int, - amqp_transport: AmqpTransport + max_wait_time: float = 1 ): self._buffered_queue: queue.Queue = queue.Queue() self._max_buffer_len = max_buffer_length diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py index 01b300560dc0..3690919e04fa 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py @@ -31,10 +31,10 @@ def __init__( eventhub_name: str, max_message_size_on_link: int, *, + amqp_transport: AmqpTransport, max_buffer_length: int = 1500, max_wait_time: float = 1, - executor: Optional[Union[ThreadPoolExecutor, int]] = None, - amqp_transport: AmqpTransport + executor: Optional[Union[ThreadPoolExecutor, int]] = None ): self._buffered_producers: Dict[str, BufferedProducer] = {} self._partition_ids: List[str] = partitions diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index b1ec0651cf51..bf9219c1729e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -12,8 +12,6 @@ from typing import Any, Dict, Tuple, List, Optional, TYPE_CHECKING, cast, Union from datetime import timedelta from urllib.parse import urlparse, quote_plus -import six -from uamqp import utils from azure.core.credentials import ( AccessToken, @@ -27,7 +25,7 @@ from ._transport._uamqp_transport import UamqpTransport from .exceptions import ClientClosedError from ._configuration import Configuration -from ._utils import utc_from_timestamp, parse_sas_credential +from ._utils import utc_from_timestamp, parse_sas_credential, generate_sas_token from ._connection_manager import get_connection_manager from ._constants import ( CONTAINER_PREFIX, @@ -150,11 +148,8 @@ def _generate_sas_token(uri, policy, key, expiry=None): expiry = timedelta(hours=1) # Default to 1 hour. abs_expiry = int(time.time()) + expiry.seconds - encoded_uri = quote_plus(uri).encode("utf-8") # pylint: disable=no-member - encoded_policy = quote_plus(policy).encode("utf-8") # pylint: disable=no-member - encoded_key = key.encode("utf-8") - token = utils.create_sas_token(encoded_policy, encoded_key, encoded_uri, expiry) + token = generate_sas_token(uri, policy, key, abs_expiry).encode() return AccessToken(token=token, expires_on=abs_expiry) @@ -289,7 +284,7 @@ def __init__( credential: CredentialTypes, **kwargs: Any, ) -> None: - self._uamqp_transport = kwargs.pop("uamqp_transport", True) + uamqp_transport = kwargs.pop("uamqp_transport", True) self._amqp_transport = kwargs.pop("amqp_transport", UamqpTransport) self.eventhub_name = eventhub_name @@ -308,7 +303,7 @@ def __init__( self._auto_reconnect = kwargs.get("auto_reconnect", True) self._auth_uri = f"sb://{self._address.hostname}{self._address.path}" self._config = Configuration( - uamqp_transport=self._uamqp_transport, + uamqp_transport=uamqp_transport, hostname=self._address.hostname, **kwargs, ) @@ -418,7 +413,7 @@ def _management_request( description = response.application_properties.get( MGMT_STATUS_DESC ) # type: Optional[Union[str, bytes]] - if description and isinstance(description, six.binary_type): + if description and isinstance(description, bytes): description = description.decode("utf-8") if status_code < 400: return response diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index a616197aa870..76f553a68e9f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -55,6 +55,7 @@ AmqpMessageProperties, ) from ._transport._uamqp_transport import UamqpTransport +from uamqp import types if TYPE_CHECKING: from uamqp import Message as uamqp_Message, BatchMessage as uamqp_BatchMessage @@ -224,8 +225,8 @@ def _from_message( :rtype: ~azure.eventhub.EventData """ event_data = cls(body="") - event_data._message = message # pylint: disable=protected-access + event_data._message = message event_data._raw_amqp_message = ( raw_amqp_message if raw_amqp_message @@ -293,11 +294,6 @@ def partition_key(self) -> Optional[bytes]: :rtype: bytes """ - # TODO: Ask Anna. can we remove the try and just do except? Haven't seen a case where symbol is used to get. - # try: - # return self._raw_amqp_message.annotations[types.AMQPSymbol(PROP_PARTITION_KEY)] - # except KeyError: - # return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None) return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None) @property @@ -310,7 +306,6 @@ def properties(self) -> Dict[Union[str, bytes], Any]: @properties.setter def properties(self, value: Dict[Union[str, bytes], Any]): - # type: (Dict[Union[str, bytes], Any]) -> None """Application-defined properties on the event. :param dict value: The application properties for the EventData. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py index 0613390f9bda..016fed03d4f4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py @@ -3,25 +3,26 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations from typing import TYPE_CHECKING from threading import Lock from enum import Enum -from uamqp import c_uamqp, Connection as uamqp_Connection +from ._transport._uamqp_transport import UamqpTransport from ._constants import TransportType if TYPE_CHECKING: + from uamqp.authentication import JWTTokenAuth + from uamqp import Connection + try: from typing_extensions import Protocol except ImportError: Protocol = object # type: ignore - from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth - class ConnectionManager(Protocol): - def get_connection( - self, host: str, auth: uamqp_JWTTokenAuth - ) -> uamqp_Connection: + def get_connection(self, host, auth): + # type: (str, 'JWTTokenAuth') -> Connection pass def close_connection(self): @@ -39,10 +40,7 @@ class _ConnectionMode(Enum): class _SharedConnectionManager(object): # pylint:disable=too-many-instance-attributes def __init__(self, **kwargs): self._lock = Lock() - self._conn: uamqp_Connection = None - - self._lock = Lock() - self._conn = None # type: uamqp_Connection + self._conn: Connection = None self._container_id = kwargs.get("container_id") self._debug = kwargs.get("debug") @@ -57,14 +55,15 @@ def __init__(self, **kwargs): self._remote_idle_timeout_empty_frame_send_ratio = kwargs.get( "remote_idle_timeout_empty_frame_send_ratio" ) + self._amqp_transport = kwargs.get("amqp_transport", UamqpTransport) - def get_connection(self, host, auth): - # type: (str, uamqp_JWTTokenAuth) -> uamqp_Connection + def get_connection(self, *, host: str, auth: JWTTokenAuth, endpoint: str) -> Connection: with self._lock: if self._conn is None: - self._conn = uamqp_Connection( - host, - auth, + self._conn = self._amqp_transport.create_connection( + host=host, + auth=auth, + endpoint=endpoint, container_id=self._container_id, max_frame_size=self._max_frame_size, channel_max=self._channel_max, @@ -81,18 +80,14 @@ def close_connection(self): # type: () -> None with self._lock: if self._conn: - self._conn.destroy() + self._amqp_transport.close_connection(self._conn) self._conn = None def reset_connection_if_broken(self): # type: () -> None with self._lock: - if self._conn and self._conn._state in ( # pylint:disable=protected-access - c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member - c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member - c_uamqp.ConnectionState.DISCARDING, # pylint:disable=c-extension-no-member - c_uamqp.ConnectionState.END, # pylint:disable=c-extension-no-member - ): + conn_state = self._amqp_transport.get_connection_state(self._conn) + if self._conn and conn_state in self._amqp_transport.CONNECTION_CLOSING_STATES: self._conn = None @@ -100,8 +95,8 @@ class _SeparateConnectionManager(object): def __init__(self, **kwargs): pass - def get_connection(self, endpoint): # pylint:disable=unused-argument, no-self-use - # type: (str) -> None + def get_connection(self, host, auth): # pylint:disable=unused-argument, no-self-use + # type: (str, JWTTokenAuth) -> None return None def close_connection(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index 98d3fe21c86b..1cc6656c6604 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -99,9 +99,8 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs: Any) self._error = None self._timeout = 0 self._idle_timeout = (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) if idle_timeout else None - partition = self._source.split("/")[-1] - self._partition = partition - self._name = f"EHConsumer-{uuid.uuid4()}-partition{partition}" + self._partition = self._source.split("/")[-1] + self._name = f"EHConsumer-{uuid.uuid4()}-partition{self._partition}" if owner_level is not None: link_properties[EPOCH_SYMBOL] = int(owner_level) link_property_timeout_ms = ( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index 4e990bf78e50..86549929113c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -125,7 +125,7 @@ def __init__( self._condition: Optional[Exception] = None self._lock = threading.Lock() self._link_properties = self._amqp_transport.create_link_properties( - {TIMEOUT_SYMBOL: int(self._timeout * 1000)} + {TIMEOUT_SYMBOL: int(self._timeout * self._amqp_transport.TIMEOUT_FACTOR)} ) def _create_handler( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index 38536c9483e9..76d975d04cf6 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -295,7 +295,7 @@ def _buffered_send_batch(self, event_data_batch, **kwargs): def _buffered_send_event(self, event, **kwargs): partition_key = kwargs.get("partition_key") - set_event_partition_key(event, partition_key) + set_event_partition_key(event, partition_key, self._amqp_transport) timeout = kwargs.get("timeout") timeout_time = time.time() + timeout if timeout else None self._buffered_send( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py index 38266f53027d..290503e171a5 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -13,6 +13,7 @@ class AmqpTransport(ABC): MAX_FRAME_SIZE_BYTES = None IDLE_TIMEOUT_FACTOR = None MESSAGE = None + CONNECTION_CLOSING_STATES = None # define symbols PRODUCT_SYMBOL = None @@ -73,6 +74,41 @@ def create_link_properties(link_properties): :rtype: dict """ + @staticmethod + @abstractmethod + def create_connection(**kwargs): + """ + Creates and returns the uamqp Connection object. + :keyword str host: The hostname, used by uamqp. + :keyword JWTTokenAuth auth: The auth, used by uamqp. + :keyword str endpoint: The endpoint, used by pyamqp. + :keyword str container_id: Required. + :keyword int max_frame_size: Required. + :keyword int channel_max: Required. + :keyword int idle_timeout: Required. + :keyword Dict properties: Required. + :keyword int remote_idle_timeout_empty_frame_send_ratio: Required. + :keyword error_policy: Required. + :keyword bool debug: Required. + :keyword str encoding: Required. + """ + + @staticmethod + @abstractmethod + def close_connection(connection): + """ + Closes existing connection. + :param connection: uamqp or pyamqp Connection. + """ + + @staticmethod + @abstractmethod + def get_connection_state(connection): + """ + Gets connection state. + :param connection: uamqp or pyamqp Connection. + """ + @staticmethod @abstractmethod def create_send_client(*, config, **kwargs): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index a3ef0d18f52a..24c4c80598b0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -9,6 +9,7 @@ try: from uamqp import ( + c_uamqp, BatchMessage, constants, MessageBodyType, @@ -22,6 +23,7 @@ AMQPClient, compat, errors, + Connection, ) from uamqp.message import ( MessageHeader, @@ -84,6 +86,12 @@ class UamqpTransport(AmqpTransport): MAX_FRAME_SIZE_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES IDLE_TIMEOUT_FACTOR = 1000 MESSAGE = Message + CONNECTION_CLOSING_STATES = ( # pylint:disable=protected-access + c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.DISCARDING, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.END, # pylint:disable=c-extension-no-member + ) # define symbols PRODUCT_SYMBOL = types.AMQPSymbol("product") @@ -205,6 +213,48 @@ def create_link_properties(link_properties): """ return {types.AMQPSymbol(symbol): types.AMQPLong(value) for (symbol, value) in link_properties.items()} + @staticmethod + def create_connection(**kwargs): + """ + Creates and returns the uamqp Connection object. + :keyword str host: The hostname, used by uamqp. + :keyword JWTTokenAuth auth: The auth, used by uamqp. + :keyword str endpoint: The endpoint, used by pyamqp. + :keyword str container_id: Required. + :keyword int max_frame_size: Required. + :keyword int channel_max: Required. + :keyword int idle_timeout: Required. + :keyword Dict properties: Required. + :keyword int remote_idle_timeout_empty_frame_send_ratio: Required. + :keyword error_policy: Required. + :keyword bool debug: Required. + :keyword str encoding: Required. + """ + endpoint = kwargs.pop("endpoint") # pylint:disable=unused-variable + host = kwargs.pop("host") + auth = kwargs.pop("auth") + return Connection( + host, + auth, + **kwargs + ) + + @staticmethod + def close_connection(connection): + """ + Closes existing connection. + :param connection: uamqp or pyamqp Connection. + """ + connection.destroy() + + @staticmethod + def get_connection_state(connection): + """ + Gets connection state. + :param connection: uamqp or pyamqp Connection. + """ + return connection._state + @staticmethod def create_send_client(*, config, **kwargs): # pylint:disable=unused-argument """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index de4502be188a..7815e71f121b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -10,6 +10,11 @@ import datetime import calendar import logging +from base64 import b64encode +from hashlib import sha256 +from hmac import HMAC +from urllib.parse import urlencode, quote_plus +import time from typing import ( TYPE_CHECKING, Type, @@ -39,11 +44,8 @@ PROP_RUNTIME_INFO_RETRIEVAL_TIME_UTC, PROP_LAST_ENQUEUED_OFFSET, PROP_TIMESTAMP, - PROP_PARTITION_KEY ) -PROP_PARTITION_KEY_AMQP_SYMBOL = uamqp_types.AMQPSymbol(PROP_PARTITION_KEY) - if TYPE_CHECKING: # pylint: disable=ungrouped-imports @@ -130,8 +132,11 @@ def send_context_manager(): yield None -def set_event_partition_key(event, partition_key): - # type: (Union[AmqpAnnotatedMessage, EventData], Optional[Union[bytes, str]]) -> None +def set_event_partition_key( + event: Union[AmqpAnnotatedMessage, EventData], + partition_key: Optional[Union[bytes, str]], + amqp_transport: AmqpTransport +) -> None: if not partition_key: return @@ -144,7 +149,7 @@ def set_event_partition_key(event, partition_key): if annotations is None: annotations = {} annotations[ - PROP_PARTITION_KEY_AMQP_SYMBOL + amqp_transport.PROP_PARTITION_KEY_AMQP_SYMBOL ] = partition_key # pylint:disable=protected-access if not raw_message.header: raw_message.header = AmqpMessageHeader(header=True) @@ -289,12 +294,14 @@ def transform_outbound_single_message(message, message_type, to_outgoing_amqp_me """ try: # pylint: disable=protected-access - # EventData + # If EventData, set EventData._message to uamqp/pyamqp.Message right before sending. message._message = to_outgoing_amqp_message(message.raw_amqp_message) return message # type: ignore except AttributeError: # pylint: disable=protected-access - # AmqpAnnotatedMessage + # If AmqpAnnotatedMessage, create EventData object with _from_message. + # event_data._message will be set to outgoing uamqp/pyamqp.Message. + # event_data.raw_amqp_message will be set to AmqpAnnotatedMessage. amqp_message = to_outgoing_amqp_message(message) return message_type._from_message( message=amqp_message, raw_amqp_message=message # type: ignore @@ -335,3 +342,32 @@ def decode_with_recurse(data, encoding="UTF-8"): return decoded_list return data + + +def generate_sas_token(audience, policy, key, expiry=None): + """ + Generate a sas token according to the given audience, policy, key and expiry + :param str audience: + :param str policy: + :param str key: + :param int expiry: abs expiry time + :rtype: str + """ + if not expiry: + expiry = int(time.time()) + 3600 # Default to 1 hour. + + encoded_uri = quote_plus(audience) + encoded_policy = quote_plus(policy).encode("utf-8") + encoded_key = key.encode("utf-8") + + ttl = int(expiry) + sign_key = '%s\n%d' % (encoded_uri, ttl) + signature = b64encode(HMAC(encoded_key, sign_key.encode('utf-8'), sha256).digest()) + result = { + 'sr': audience, + 'sig': signature, + 'se': str(ttl) + } + if policy: + result['skn'] = encoded_policy + return 'SharedAccessSignature ' + urlencode(result) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py index 1e041fde6b9e..ee91735ab607 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py @@ -210,7 +210,7 @@ def __init__( **kwargs: Any ) -> None: self._internal_kwargs = get_dict_with_loop_if_needed(kwargs.get("loop", None)) - self._uamqp_transport = kwargs.pop("uamqp_transport", True) + uamqp_transport = kwargs.pop("uamqp_transport", True) self._amqp_transport = UamqpTransportAsync if isinstance(credential, AzureSasCredential): self._credential = EventhubAzureSasTokenCredentialAsync(credential) # type: ignore @@ -222,7 +222,7 @@ def __init__( fully_qualified_namespace=fully_qualified_namespace, eventhub_name=eventhub_name, credential=self._credential, - uamqp_transport=self._uamqp_transport, + uamqp_transport=uamqp_transport, amqp_transport=self._amqp_transport, **kwargs ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py index ab91e4939967..e10a951695bc 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py @@ -3,17 +3,17 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations from typing import TYPE_CHECKING from asyncio import Lock -from uamqp import c_uamqp -from uamqp.async_ops import ConnectionAsync - +from ._transport._uamqp_transport_async import UamqpTransportAsync from .._connection_manager import _ConnectionMode from .._constants import TransportType if TYPE_CHECKING: from uamqp.authentication import JWTTokenAsync + from uamqp.async_ops import ConnectionAsync try: from typing_extensions import Protocol @@ -22,7 +22,7 @@ class ConnectionManager(Protocol): async def get_connection( - self, host: str, auth: "JWTTokenAsync" + self, host: str, auth: JWTTokenAsync ) -> ConnectionAsync: pass @@ -52,11 +52,12 @@ def __init__(self, **kwargs) -> None: self._remote_idle_timeout_empty_frame_send_ratio = kwargs.get( "remote_idle_timeout_empty_frame_send_ratio" ) + self._amqp_transport = kwargs.get("amqp_transport", UamqpTransportAsync) - async def get_connection(self, host: str, auth: "JWTTokenAsync") -> ConnectionAsync: + async def get_connection(self, host: str, auth: JWTTokenAsync) -> ConnectionAsync: async with self._lock: if self._conn is None: - self._conn = ConnectionAsync( + self._conn = self._amqp_transport.create_connection_async( host, auth, container_id=self._container_id, @@ -75,17 +76,13 @@ async def get_connection(self, host: str, auth: "JWTTokenAsync") -> ConnectionAs async def close_connection(self) -> None: async with self._lock: if self._conn: - await self._conn.destroy_async() + await self._amqp_transport.close_connection_async(self._conn) self._conn = None async def reset_connection_if_broken(self) -> None: async with self._lock: - if self._conn and self._conn._state in ( # pylint:disable=protected-access - c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member - c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member - c_uamqp.ConnectionState.DISCARDING, # pylint:disable=c-extension-no-member - c_uamqp.ConnectionState.END, # pylint:disable=c-extension-no-member - ): + conn_state = self._amqp_transport.get_connection_state(self._conn) + if self._conn and conn_state in self._amqp_transport.CONNECTION_CLOSING_STATES: self._conn = None @@ -93,7 +90,7 @@ class _SeparateConnectionManager(object): def __init__(self, **kwargs) -> None: pass - async def get_connection(self, host: str, auth: "JWTTokenAsync") -> None: + async def get_connection(self, host: str, auth: JWTTokenAsync) -> None: pass # return None async def close_connection(self) -> None: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py index 91c0d1a29f9d..65a1e8ff981e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py @@ -93,9 +93,8 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs) -> N self._timeout = 0 self._idle_timeout = (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) if idle_timeout else None link_properties: Dict[types.AMQPType, types.AMQPType] = {} - partition = self._source.split("/")[-1] - self._partition = partition - self._name = f"EHReceiver-{uuid.uuid4()}-partition{partition}" + self._partition = self._source.split("/")[-1] + self._name = f"EHReceiver-{uuid.uuid4()}-partition{self._partition}" if owner_level is not None: link_properties[EPOCH_SYMBOL] = int(owner_level) link_property_timeout_ms = ( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py index 6fb84299bfc2..cef98a1c87d3 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py @@ -99,7 +99,7 @@ def __init__(self, client: EventHubProducerClient, target: str, **kwargs) -> Non self._condition: Optional[Exception] = None self._lock = asyncio.Lock(**self._internal_kwargs) self._link_properties = self._amqp_transport.create_link_properties( - {TIMEOUT_SYMBOL: int(self._timeout * 1000)} + {TIMEOUT_SYMBOL: int(self._timeout * self._amqp_transport.TIMEOUT_FACTOR)} ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py index bfe03aa365aa..290a36641437 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py @@ -277,7 +277,7 @@ async def _buffered_send_batch(self, event_data_batch, **kwargs): async def _buffered_send_event(self, event, **kwargs): partition_key = kwargs.get("partition_key") - set_event_partition_key(event, partition_key) + set_event_partition_key(event, partition_key, self._amqp_transport) timeout = kwargs.get("timeout") timeout_time = time.time() + timeout if timeout else None await self._buffered_send( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py index b495f2028dc7..61842273c2dc 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py @@ -73,6 +73,41 @@ def create_link_properties(link_properties): :rtype: dict """ + @staticmethod + @abstractmethod + async def create_connection_async(**kwargs): + """ + Creates and returns the uamqp async Connection object. + :keyword str host: The hostname, used by uamqp. + :keyword JWTTokenAuth auth: The auth, used by uamqp. + :keyword str endpoint: The endpoint, used by pyamqp. + :keyword str container_id: Required. + :keyword int max_frame_size: Required. + :keyword int channel_max: Required. + :keyword int idle_timeout: Required. + :keyword Dict properties: Required. + :keyword int remote_idle_timeout_empty_frame_send_ratio: Required. + :keyword error_policy: Required. + :keyword bool debug: Required. + :keyword str encoding: Required. + """ + + @staticmethod + @abstractmethod + async def close_connection_async(connection): + """ + Closes existing connection. + :param connection: uamqp or pyamqp Connection. + """ + + @staticmethod + @abstractmethod + def get_connection_state(connection): + """ + Gets connection state. + :param connection: uamqp or pyamqp Connection. + """ + @staticmethod @abstractmethod def create_send_client(*, config, **kwargs): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py index 61fed458e44a..1ae98e3f7df8 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py @@ -19,6 +19,7 @@ AMQPClientAsync, errors, ) +from uamqp.async_ops import ConnectionAsync from ._base_async import AmqpTransportAsync from ..._transport._uamqp_transport import UamqpTransport @@ -40,6 +41,39 @@ class UamqpTransportAsync(UamqpTransport, AmqpTransportAsync): Class which defines uamqp-based methods used by the producer and consumer. """ + @staticmethod + async def create_connection_async(**kwargs): + """ + Creates and returns the uamqp async Connection object. + :keyword str host: The hostname, used by uamqp. + :keyword JWTTokenAuth auth: The auth, used by uamqp. + :keyword str endpoint: The endpoint, used by pyamqp. + :keyword str container_id: Required. + :keyword int max_frame_size: Required. + :keyword int channel_max: Required. + :keyword int idle_timeout: Required. + :keyword Dict properties: Required. + :keyword int remote_idle_timeout_empty_frame_send_ratio: Required. + :keyword error_policy: Required. + :keyword bool debug: Required. + :keyword str encoding: Required. + """ + endpoint = kwargs.pop("endpoint") # pylint:disable=unused-variable + host = kwargs.pop("host") + auth = kwargs.pop("auth") + return ConnectionAsync( + host, + auth, + **kwargs + ) + + @staticmethod + async def close_connection_async(connection): + """ + Closes existing connection. + :param connection: uamqp or pyamqp Connection. + """ + @staticmethod def create_send_client(*, config, **kwargs): # pylint:disable=unused-argument """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index c883a55f5cc3..d2779e0a6cfc 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -5,7 +5,7 @@ # ------------------------------------------------------------------------- from __future__ import annotations -from typing import Optional, Any, cast, Mapping, Dict, Union +from typing import Optional, Any, cast, Mapping, Dict, Union, List from ._amqp_utils import normalized_data_body, normalized_sequence_body from ._constants import AmqpMessageBodyType @@ -46,9 +46,9 @@ class AmqpAnnotatedMessage(object): def __init__(self, **kwargs): # type: (Any) -> None self._encoding = kwargs.pop("encoding", "UTF-8") - self._data_body = None - self._sequence_body = None - self._value_body = None + self._data_body: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None + self._sequence_body: Optional[List[Any]] = None + self._value_body: Any = None # internal usage only for Event Hub received message message = kwargs.pop("message", None) @@ -70,7 +70,7 @@ def __init__(self, **kwargs): "or value_body being set as the body of the AmqpAnnotatedMessage." ) - self._body_type = None + self._body_type: AmqpMessageBodyType = None # type: ignore if "data_body" in kwargs: self._data_body = normalized_data_body(kwargs.get("data_body")) self._body_type = AmqpMessageBodyType.DATA @@ -94,7 +94,7 @@ def __init__(self, **kwargs): def __str__(self) -> str: if self._body_type == AmqpMessageBodyType.DATA: - return "".join(d.decode(self._encoding) for d in self._data_body) + return "".join(d.decode(self._encoding) for d in self._data_body) # type: ignore if self._body_type == AmqpMessageBodyType.SEQUENCE: return str(self._sequence_body) if self._body_type == AmqpMessageBodyType.VALUE: @@ -163,10 +163,10 @@ def _from_amqp_message(self, message): self._delivery_annotations = message.delivery_annotations if message.delivery_annotations else {} self._application_properties = message.application_properties if message.application_properties else {} if message.data: - self._data_body = list(message.data) + self._data_body = cast(List, list(message.data)) self._body_type = AmqpMessageBodyType.DATA elif message.sequence: - self._sequence_body = list(message.sequence) + self._sequence_body = cast(List, list(message.sequence)) self._body_type = AmqpMessageBodyType.SEQUENCE else: self._value_body = message.value @@ -181,7 +181,7 @@ def body(self) -> Any: :rtype: Any """ if self._body_type == AmqpMessageBodyType.DATA: # pylint:disable=no-else-return - return (i for i in self._data_body) + return (i for i in self._data_body) # type: ignore elif self._body_type == AmqpMessageBodyType.SEQUENCE: return (i for i in self._sequence_body) elif self._body_type == AmqpMessageBodyType.VALUE: From e19a0d96b43878a3e645b6fc85b09b0f2aa3644a Mon Sep 17 00:00:00 2001 From: swathipil Date: Mon, 15 Aug 2022 10:17:11 -0700 Subject: [PATCH 12/20] Annas comments --- .../azure/eventhub/_client_base.py | 39 +++------- .../azure-eventhub/azure/eventhub/_common.py | 8 +- .../azure/eventhub/_connection_manager.py | 6 +- .../azure/eventhub/_consumer.py | 12 +-- .../azure/eventhub/_transport/_base.py | 50 +++++++++---- .../eventhub/_transport/_uamqp_transport.py | 74 +++++++++++++++---- .../azure/eventhub/aio/_client_base_async.py | 35 +++------ .../eventhub/aio/_connection_manager_async.py | 11 ++- .../eventhub/aio/_transport/_base_async.py | 45 +++++++---- .../aio/_transport/_uamqp_transport_async.py | 45 +++++------ 10 files changed, 193 insertions(+), 132 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index bf9219c1729e..b8e9dec09bd9 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -395,7 +395,10 @@ def _management_request( self._address, mgmt_auth=mgmt_auth, config=self._config ) try: - mgmt_client.open() + conn = self._conn_manager.get_connection( # pylint:disable=assignment-from-none + host=self._address.hostname, auth=mgmt_auth + ) + mgmt_client.open(connection=conn) while not mgmt_client.client_ready(): time.sleep(0.05) mgmt_msg.application_properties[ @@ -417,22 +420,7 @@ def _management_request( description = description.decode("utf-8") if status_code < 400: return response - if status_code in [401]: - raise self._amqp_transport.get_error( - self._amqp_transport.AUTH_EXCEPTION, - f"Management authentication failed. Status code: {status_code}, Description: {description!r}" - ) - if status_code in [ - 404 - ]: - return self._amqp_transport.get_error( - self._amqp_transport.CONNECTION_ERROR, - f"Management connection failed. Status code: {status_code}, Description: {description!r}" - ) - return self._amqp_transport.get_error( - self._amqp_transport.AMQP_CONNECTION_ERROR, - f"Management request error. Status code: {status_code}, Description: {description!r}" - ) + raise self._amqp_transport.get_error(status_code, description) except Exception as exception: # pylint: disable=broad-except last_exception = self._amqp_transport._handle_exception( # pylint: disable=protected-access exception, self @@ -456,7 +444,7 @@ def _add_span_request_attributes(self, span): span.add_attribute("peer.address", self._address.hostname) def _get_eventhub_properties(self) -> Dict[str, Any]: - mgmt_msg = self._amqp_transport.MESSAGE( + mgmt_msg = self._amqp_transport.build_message( application_properties={"name": self.eventhub_name} ) response = self._management_request(mgmt_msg, op_type=MGMT_OPERATION) @@ -478,7 +466,7 @@ def _get_partition_ids(self): def _get_partition_properties(self, partition_id): # type:(str) -> Dict[str, Any] - mgmt_msg = self._amqp_transport.MESSAGE( + mgmt_msg = self._amqp_transport.build_message( application_properties={ "name": self.eventhub_name, "partition": partition_id, @@ -534,7 +522,10 @@ def _open(self): self._handler.close() auth = self._client._create_auth() self._create_handler(auth) - self._handler.open() + conn = self._client._conn_manager.get_connection( # pylint: disable=protected-access + host=self._client._address.hostname, auth=auth + ) + self._handler.open(connection=conn) while not self._handler.client_ready(): time.sleep(0.05) self._max_message_size_on_link = ( @@ -553,13 +544,7 @@ def _close_connection(self): self._client._conn_manager.reset_connection_if_broken() # pylint: disable=protected-access def _handle_exception(self, exception): - if not self.running and isinstance( - exception, self._amqp_transport.TIMEOUT_EXCEPTION - ): - exception = self._amqp_transport.get_error( - self._amqp_transport.AUTH_EXCEPTION, - "Authorization timeout." - ) + exception = self._amqp_transport.check_timeout_exception(self, exception) return self._amqp_transport._handle_exception( # pylint: disable=protected-access exception, self ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 76f553a68e9f..d7dad5a6156d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -21,8 +21,6 @@ ) from typing_extensions import TypedDict -import six - from ._utils import ( trace_message, utc_from_timestamp, @@ -389,7 +387,7 @@ def body_as_str(self, encoding: str = "UTF-8") -> str: return self._decode_non_data_body_as_str(encoding=encoding) return "".join(b.decode(encoding) for b in cast(Iterable[bytes], data)) except TypeError: - return six.text_type(data) + return str(data) except: # pylint: disable=bare-except pass try: @@ -512,7 +510,7 @@ def __init__( if partition_key and not isinstance( - partition_key, (six.text_type, six.binary_type) + partition_key, (str, bytes) ): _LOGGER.info( "WARNING: Setting partition_key of non-string value on the events to be sent is discouraged " @@ -524,7 +522,7 @@ def __init__( self.max_size_in_bytes = ( max_size_in_bytes or self._amqp_transport.MAX_FRAME_SIZE_BYTES ) - self._message = self._amqp_transport.BATCH_MESSAGE(data=[]) + self._message = self._amqp_transport.build_batch_message(data=[]) self._partition_id = partition_id self._partition_key = partition_key diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py index 016fed03d4f4..2ccd0fe80e07 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------------------------- from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from threading import Lock from enum import Enum @@ -57,7 +57,9 @@ def __init__(self, **kwargs): ) self._amqp_transport = kwargs.get("amqp_transport", UamqpTransport) - def get_connection(self, *, host: str, auth: JWTTokenAuth, endpoint: str) -> Connection: + def get_connection( + self, *, host: Optional[str] = None, auth: Optional[JWTTokenAuth] = None, endpoint: Optional[str] = None + ) -> Connection: with self._lock: if self._conn is None: self._conn = self._amqp_transport.create_connection( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index 1cc6656c6604..3aa6423fdd38 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -131,6 +131,7 @@ def _create_handler(self, auth: uamqp_JWTTokenAuth) -> None: network_trace=self._client._config.network_tracing, # pylint:disable=protected-access link_credit=self._prefetch, link_properties=self._link_properties, + timeout=self._timeout, idle_timeout=self._idle_timeout, retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, @@ -166,7 +167,10 @@ def _open(self) -> bool: self._handler.close() auth = self._client._create_auth() self._create_handler(auth) - self._handler.open() + conn = self._client._conn_manager.get_connection( # pylint: disable=protected-access + host=self._client._address.hostname, auth=auth + ) + self._handler.open(connection=conn) while not self._handler.client_ready(): time.sleep(0.05) self.handler_ready = True @@ -190,11 +194,7 @@ def receive(self, batch=False, max_batch_size=300, max_wait_time=None): self._handler.do_work(batch=self._prefetch) # type: ignore break except Exception as exception: # pylint: disable=broad-except - if ( - isinstance(exception, self._amqp_transport.AMQP_LINK_ERROR) - and exception.condition == self._amqp_transport.LINK_STOLEN_CONDITION # pylint: disable=no-member - ): - raise self._handle_exception(exception) + self._amqp_transport.check_link_stolen(self, exception) if not self.running: # exit by close return if self._last_received_event: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py index 290503e171a5..8fff57dc543d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -9,10 +9,8 @@ class AmqpTransport(ABC): Abstract class that defines a set of common methods needed by producer and consumer. """ # define constants - BATCH_MESSAGE = None MAX_FRAME_SIZE_BYTES = None IDLE_TIMEOUT_FACTOR = None - MESSAGE = None CONNECTION_CLOSING_STATES = None # define symbols @@ -23,12 +21,21 @@ class AmqpTransport(ABC): USER_AGENT_SYMBOL = None PROP_PARTITION_KEY_AMQP_SYMBOL = None - # errors - AMQP_LINK_ERROR = None - LINK_STOLEN_CONDITION = None - MGMT_AUTH_EXCEPTION = None - CONNECTION_ERROR = None - AMQP_CONNECTION_ERROR = None + @staticmethod + @abstractmethod + def build_message(**kwargs): + """ + Creates a uamqp.Message or pyamqp.Message with given arguments. + :rtype: uamqp.Message or pyamqp.Message + """ + + @staticmethod + @abstractmethod + def build_batch_message(**kwargs): + """ + Creates a uamqp.BatchMessage or pyamqp.BatchMessage with given arguments. + :rtype: uamqp.BatchMessage or pyamqp.BatchMessage + """ @staticmethod @abstractmethod @@ -202,6 +209,15 @@ def open_receive_client(*, handler, client, auth): :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. """ + @staticmethod + @abstractmethod + def check_link_stolen(consumer, exception): + """ + Checks if link stolen and handles exception. + :param consumer: The EventHubConsumer. + :param exception: Exception to check. + """ + @staticmethod @abstractmethod def create_token_auth(auth_uri, get_token, token_type, config, **kwargs): @@ -250,10 +266,18 @@ def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): @staticmethod @abstractmethod - def get_error(error, message, *, condition=None): + def get_error(status_code, description): + """ + Gets error corresponding to status code. + :param status_code: Status code. + :param str description: Description of error. + """ + + @staticmethod + @abstractmethod + def check_timeout_exception(base, exception): """ - Gets error and passes in error message, and, if applicable, condition. - :param error: The error to raise. - :param str message: Error message. - :param condition: Optional error condition. Will not be used by uamqp. + Checks if timeout exception. + :param base: ClientBase. + :param exception: Exception to check. """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index 24c4c80598b0..d90011dbe664 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -82,10 +82,8 @@ class UamqpTransport(AmqpTransport): Class which defines uamqp-based methods used by the producer and consumer. """ # define constants - BATCH_MESSAGE = BatchMessage MAX_FRAME_SIZE_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES IDLE_TIMEOUT_FACTOR = 1000 - MESSAGE = Message CONNECTION_CLOSING_STATES = ( # pylint:disable=protected-access c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member @@ -101,13 +99,21 @@ class UamqpTransport(AmqpTransport): USER_AGENT_SYMBOL = types.AMQPSymbol("user-agent") PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) - # define errors and conditions - AMQP_LINK_ERROR = errors.LinkDetach - LINK_STOLEN_CONDITION = constants.ErrorCodes.LinkStolen - AUTH_EXCEPTION = errors.AuthenticationException - CONNECTION_ERROR = ConnectError - AMQP_CONNECTION_ERROR = errors.AMQPConnectionError - TIMEOUT_EXCEPTION = compat.TimeoutException + @staticmethod + def build_message(**kwargs): + """ + Creates a uamqp.Message or pyamqp.Message with given arguments. + :rtype: uamqp.Message or pyamqp.Message + """ + return Message(**kwargs) + + @staticmethod + def build_batch_message(**kwargs): + """ + Creates a uamqp.BatchMessage or pyamqp.BatchMessage with given arguments. + :rtype: uamqp.BatchMessage or pyamqp.BatchMessage + """ + return BatchMessage(**kwargs) @staticmethod def to_outgoing_amqp_message(annotated_message): @@ -434,6 +440,19 @@ def open_receive_client(*, handler, client, auth): client._address.hostname, auth )) + @staticmethod + def check_link_stolen(consumer, exception): + """ + Checks if link stolen and handles exception. + :param consumer: The EventHubConsumer. + :param exception: Exception to check. + """ + if ( + isinstance(exception, errors.LinkDetach) + and exception.condition == constants.ErrorCodes.LinkStolen # pylint: disable=no-member + ): + raise consumer._handle_exception(exception) # pylint: disable=protected-access + @staticmethod def create_token_auth(auth_uri, get_token, token_type, config, **kwargs): """ @@ -514,14 +533,39 @@ def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): ) @staticmethod - def get_error(error, message, *, condition=None): # pylint: disable=unused-argument + def get_error(status_code, description): + """ + Gets error corresponding to status code. + :param status_code: Status code. + :param str description: Description of error. + """ + if status_code in [401]: + return errors.AuthenticationException( + f"Management authentication failed. Status code: {status_code}, Description: {description!r}" + ) + if status_code in [404]: + return ConnectError( + f"Management connection failed. Status code: {status_code}, Description: {description!r}" + ) + return errors.AMQPConnectionError( + f"Management request error. Status code: {status_code}, Description: {description!r}" + ) + + @staticmethod + def check_timeout_exception(base, exception): """ - Gets error and passes in error message, and, if applicable, condition. - :param error: The error to raise. - :param str message: Error message. - :param condition: Optional error condition. Will not be used by uamqp. + Checks if timeout exception. + :param base: ClientBase. + :param exception: Exception to check. """ - return error(message) + if not base.running and isinstance( + exception, compat.TimeoutException + ): + exception = UamqpTransport.get_error( + errors.AuthenticationException, + "Authorization timeout." + ) + return exception @staticmethod def _create_eventhub_exception(exception): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py index ee91735ab607..898bbadc6ec2 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py @@ -316,7 +316,10 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> self._address, mgmt_auth=mgmt_auth, config=self._config ) try: - await mgmt_client.open_async() + conn = await self._conn_manager_async.get_connection( + host=self._address.hostname, auth=mgmt_auth + ) + await mgmt_client.open_async(connection=conn) while not await mgmt_client.client_ready_async(): await asyncio.sleep(0.05) mgmt_msg.application_properties[ @@ -338,20 +341,7 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> description = description.decode("utf-8") if status_code < 400: return response - if status_code in [401]: - raise self._amqp_transport.get_error( - self._amqp_transport.AUTH_EXCEPTION, - f"Management authentication failed. Status code: {status_code}, Description: {description!r}" - ) - if status_code in [404]: - return self._amqp_transport.get_error( - self._amqp_transport.CONNECTION_ERROR, - f"Management connection failed. Status code: {status_code}, Description: {description!r}" - ) - return self._amqp_transport.get_error( - self._amqp_transport.AMQP_CONNECTION_ERROR, - f"Management request error. Status code: {status_code}, Description: {description!r}" - ) + raise self._amqp_transport.get_error(status_code, description) except asyncio.CancelledError: # pylint: disable=try-except-raise raise except Exception as exception: # pylint:disable=broad-except @@ -369,7 +359,7 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> await mgmt_client.close_async() async def _get_eventhub_properties_async(self) -> Dict[str, Any]: - mgmt_msg = mgmt_msg = self._amqp_transport.MESSAGE( + mgmt_msg = mgmt_msg = self._amqp_transport.build_message( application_properties={"name": self.eventhub_name} ) response = await self._management_request_async( @@ -393,7 +383,7 @@ async def _get_partition_ids_async(self) -> List[str]: async def _get_partition_properties_async( self, partition_id: str ) -> Dict[str, Any]: - mgmt_msg = self._amqp_transport.MESSAGE( + mgmt_msg = self._amqp_transport.build_message( application_properties={ "name": self.eventhub_name, "partition": partition_id, @@ -454,7 +444,10 @@ async def _open(self) -> None: await self._handler.close_async() auth = await self._client._create_auth_async() self._create_handler(auth) - await self._handler.open_async() + conn = await self._conn_manager_async.get_connection( + host=self._address.hostname, auth=auth + ) + await self._handler.open_async(connection=conn) while not await self._handler.client_ready_async(): await asyncio.sleep(0.05, **self._internal_kwargs) self._max_message_size_on_link = ( @@ -474,11 +467,7 @@ async def _close_connection_async(self) -> None: await self._client._conn_manager_async.reset_connection_if_broken() # pylint:disable=protected-access async def _handle_exception(self, exception: Exception) -> Exception: - if not self.running and isinstance(exception, self._amqp_transport.TIMEOUT_EXCEPTION): - exception = self._amqp_transport.get_error( - self._amqp_transport.AUTH_EXCEPTION, - "Authorization timeout." - ) + exception = self._amqp_transport.check_timeout_exception(self, exception) return await self._amqp_transport._handle_exception_async( # pylint: disable=protected-access exception, self ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py index e10a951695bc..d9e524391c6a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------------------------- from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from asyncio import Lock from ._transport._uamqp_transport_async import UamqpTransportAsync @@ -54,12 +54,15 @@ def __init__(self, **kwargs) -> None: ) self._amqp_transport = kwargs.get("amqp_transport", UamqpTransportAsync) - async def get_connection(self, host: str, auth: JWTTokenAsync) -> ConnectionAsync: + async def get_connection( + self, *, host: Optional[str] = None, auth: Optional[JWTTokenAsync] = None, endpoint: Optional[str] = None + ) -> ConnectionAsync: async with self._lock: if self._conn is None: self._conn = self._amqp_transport.create_connection_async( - host, - auth, + host=host, + auth=auth, + endpoint=endpoint, container_id=self._container_id, max_frame_size=self._max_frame_size, channel_max=self._channel_max, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py index 61842273c2dc..2aa7e29cf453 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py @@ -9,10 +9,8 @@ class AmqpTransportAsync(ABC): Abstract class that defines a set of common methods needed by producer and consumer. """ # define constants - BATCH_MESSAGE = None MAX_FRAME_SIZE_BYTES = None IDLE_TIMEOUT_FACTOR = None - MESSAGE = None # define symbols PRODUCT_SYMBOL = None @@ -22,12 +20,21 @@ class AmqpTransportAsync(ABC): USER_AGENT_SYMBOL = None PROP_PARTITION_KEY_AMQP_SYMBOL = None - # errors - AMQP_LINK_ERROR = None - LINK_STOLEN_CONDITION = None - MGMT_AUTH_EXCEPTION = None - CONNECTION_ERROR = None - AMQP_CONNECTION_ERROR = None + @staticmethod + @abstractmethod + def build_message(**kwargs): + """ + Creates a uamqp.Message or pyamqp.Message with given arguments. + :rtype: uamqp.Message or pyamqp.Message + """ + + @staticmethod + @abstractmethod + def build_batch_message(**kwargs): + """ + Creates a uamqp.BatchMessage or pyamqp.BatchMessage with given arguments. + :rtype: uamqp.BatchMessage or pyamqp.BatchMessage + """ @staticmethod @abstractmethod @@ -183,10 +190,10 @@ def create_receive_client(*, config, **kwargs): @staticmethod @abstractmethod - async def receive_messages(handler, batch, max_batch_size, max_wait_time): + async def receive_messages(consumer, batch, max_batch_size, max_wait_time): """ Receives messages, creates events, and returns them by calling the on received callback. - :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.aio.EventHubConsumer consumer: The EventHubConsumer. :param bool batch: If receive batch or single event. :param int max_batch_size: Max batch size. :param int or None max_wait_time: Max wait time. @@ -240,10 +247,18 @@ async def mgmt_client_request_async(mgmt_client, mgmt_msg, **kwargs): @staticmethod @abstractmethod - def get_error(error, message, *, condition=None): + def get_error(status_code, description): + """ + Gets error corresponding to status code. + :param status_code: Status code. + :param str description: Description of error. + """ + + @staticmethod + @abstractmethod + def check_timeout_exception(base, exception): """ - Gets error and passes in error message, and, if applicable, condition. - :param error: The error to raise. - :param str message: Error message. - :param condition: Optional error condition. Will not be used by uamqp. + Checks if timeout exception. + :param base: ClientBase. + :param exception: Exception to check. """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py index 1ae98e3f7df8..1f505af2150d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py @@ -73,6 +73,7 @@ async def close_connection_async(connection): Closes existing connection. :param connection: uamqp or pyamqp Connection. """ + await connection.destroy_async() @staticmethod def create_send_client(*, config, **kwargs): # pylint:disable=unused-argument @@ -175,68 +176,68 @@ def create_receive_client(*, config, **kwargs): # pylint:disable=unused-argument return client @staticmethod - async def receive_messages(handler, batch, max_batch_size, max_wait_time): + async def receive_messages(consumer, batch, max_batch_size, max_wait_time): """ Receives messages, creates events, and returns them by calling the on received callback. - :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.aio.EventHubConsumer consumer: The EventHubConsumer. :param bool batch: If receive batch or single event. :param int max_batch_size: Max batch size. :param int or None max_wait_time: Max wait time. """ # pylint:disable=protected-access max_retries = ( - handler._client._config.max_retries # pylint:disable=protected-access + consumer._client._config.max_retries # pylint:disable=protected-access ) has_not_fetched_once = True # ensure one trip when max_wait_time is very small deadline = time.time() + (max_wait_time or 0) # max_wait_time can be None - while len(handler._message_buffer) < max_batch_size and ( + while len(consumer._message_buffer) < max_batch_size and ( time.time() < deadline or has_not_fetched_once ): retried_times = 0 has_not_fetched_once = False while retried_times <= max_retries: try: - await handler._open() + await consumer._open() await cast( - ReceiveClientAsync, handler._handler + ReceiveClientAsync, consumer._consumer ).do_work_async() # uamqp sleeps 0.05 if none received break except asyncio.CancelledError: # pylint: disable=try-except-raise raise except Exception as exception: # pylint: disable=broad-except if ( - isinstance(exception, UamqpTransportAsync.AMQP_LINK_ERROR) - and exception.condition == UamqpTransportAsync.LINK_STOLEN_CONDITION # pylint: disable=no-member + isinstance(exception, errors.LinkDetach) + and exception.condition == constants.ErrorCodes.LinkStolen # pylint: disable=no-member ): - raise await handler._handle_exception(exception) - if not handler.running: # exit by close + raise await consumer._handle_exception(exception) + if not consumer.running: # exit by close return - if handler._last_received_event: - handler._offset = handler._last_received_event.offset - last_exception = await handler._handle_exception(exception) + if consumer._last_received_event: + consumer._offset = consumer._last_received_event.offset + last_exception = await consumer._handle_exception(exception) retried_times += 1 if retried_times > max_retries: _LOGGER.info( "%r operation has exhausted retry. Last exception: %r.", - handler._name, + consumer._name, last_exception, ) raise last_exception - if handler._message_buffer: - while handler._message_buffer: + if consumer._message_buffer: + while consumer._message_buffer: if batch: events_for_callback: List[EventData] = [] - for _ in range(min(max_batch_size, len(handler._message_buffer))): - events_for_callback.append(handler._next_message_in_buffer()) - await handler._on_event_received(events_for_callback) + for _ in range(min(max_batch_size, len(consumer._message_buffer))): + events_for_callback.append(consumer._next_message_in_buffer()) + await consumer._on_event_received(events_for_callback) else: - await handler._on_event_received(handler._next_message_in_buffer()) + await consumer._on_event_received(consumer._next_message_in_buffer()) elif max_wait_time: if batch: - await handler._on_event_received([]) + await consumer._on_event_received([]) else: - await handler._on_event_received(None) + await consumer._on_event_received(None) @staticmethod async def create_token_auth_async(auth_uri, get_token, token_type, config, **kwargs): From db695be55de9c10832e1443bee8397c1941012fa Mon Sep 17 00:00:00 2001 From: swathipil Date: Mon, 15 Aug 2022 10:48:27 -0700 Subject: [PATCH 13/20] remove kwargs from EventDataBatch --- .../_buffered_producer/_buffered_producer.py | 10 ++++------ .../_buffered_producer_dispatcher.py | 3 --- .../azure/eventhub/_client_base.py | 2 +- .../azure-eventhub/azure/eventhub/_common.py | 16 ++++------------ .../azure-eventhub/azure/eventhub/_constants.py | 1 - .../azure/eventhub/_producer_client.py | 10 ++++------ .../azure/eventhub/_transport/_base.py | 1 + .../eventhub/_transport/_uamqp_transport.py | 3 ++- .../_buffered_producer_async.py | 14 ++++++-------- .../_buffered_producer_dispatcher_async.py | 5 +---- .../azure/eventhub/aio/_client_base_async.py | 2 +- .../azure/eventhub/aio/_producer_client_async.py | 6 ++---- .../azure/eventhub/aio/_transport/_base_async.py | 1 + .../tests/livetest/synctests/test_negative.py | 6 +++--- .../tests/livetest/synctests/test_send.py | 2 +- 15 files changed, 31 insertions(+), 51 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py index 4eb2cc55b1b0..f971eb592f79 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py @@ -32,7 +32,6 @@ def __init__( max_message_size_on_link: int, executor: ThreadPoolExecutor, *, - amqp_transport: AmqpTransport, max_buffer_length: int, max_wait_time: float = 1 ): @@ -50,12 +49,11 @@ def __init__( self._cur_batch: Optional[EventDataBatch] = None self._max_message_size_on_link = max_message_size_on_link self._check_max_wait_time_future = None - self._amqp_transport = amqp_transport self.partition_id = partition_id def start(self): with self._lock: - self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) self._running = True if self._max_wait_time: self._last_send_time = time.time() @@ -115,11 +113,11 @@ def put_events(self, events, timeout_time=None): self._buffered_queue.put(self._cur_batch) self._buffered_queue.put(events) # create a new batch for incoming events - self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) except ValueError: # add single event exceeds the cur batch size, create new batch self._buffered_queue.put(self._cur_batch) - self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) self._cur_batch.add(events) self._cur_buffered_len += new_events_len @@ -186,7 +184,7 @@ def flush(self, timeout_time=None, raise_error=True): break # after finishing flushing, reset cur batch and put it into the buffer self._last_send_time = time.time() - self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) _LOGGER.info("Partition %r finished flushing.", self.partition_id) def check_max_wait_time_worker(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py index 3690919e04fa..71f97f15fecd 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py @@ -31,7 +31,6 @@ def __init__( eventhub_name: str, max_message_size_on_link: int, *, - amqp_transport: AmqpTransport, max_buffer_length: int = 1500, max_wait_time: float = 1, executor: Optional[Union[ThreadPoolExecutor, int]] = None @@ -48,7 +47,6 @@ def __init__( self._max_wait_time = max_wait_time self._max_buffer_length = max_buffer_length self._existing_executor = False - self._amqp_transport = amqp_transport if not executor: self._executor = ThreadPoolExecutor() @@ -90,7 +88,6 @@ def enqueue_events( executor=self._executor, max_wait_time=self._max_wait_time, max_buffer_length=self._max_buffer_length, - amqp_transport=self._amqp_transport ) buffered_producer.start() self._buffered_producers[pid] = buffered_producer diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index b8e9dec09bd9..a8dd0c14579b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -530,7 +530,7 @@ def _open(self): time.sleep(0.05) self._max_message_size_on_link = ( self._amqp_transport.get_remote_max_message_size(self._handler) - or self._amqp_transport.MAX_FRAME_SIZE_BYTES + or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES ) self.running = True diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index d7dad5a6156d..9e0cb03730b8 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -53,7 +53,6 @@ AmqpMessageProperties, ) from ._transport._uamqp_transport import UamqpTransport -from uamqp import types if TYPE_CHECKING: from uamqp import Message as uamqp_Message, BatchMessage as uamqp_BatchMessage @@ -500,14 +499,9 @@ def __init__( self, max_size_in_bytes: Optional[int] = None, partition_id: Optional[str] = None, - partition_key: Optional[Union[str, bytes]] = None, - **kwargs, + partition_key: Optional[Union[str, bytes]] = None ) -> None: - # TODO: this changes API, check with Anna if valid - - # Might need move out message creation to right before sending. - # Might take more time to loop through events and add them all to batch in `send` than in `add` here - self._amqp_transport = kwargs.pop("amqp_transport", UamqpTransport) - + self._amqp_transport = UamqpTransport if partition_key and not isinstance( partition_key, (str, bytes) @@ -520,7 +514,7 @@ def __init__( ) self.max_size_in_bytes = ( - max_size_in_bytes or self._amqp_transport.MAX_FRAME_SIZE_BYTES + max_size_in_bytes or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES ) self._message = self._amqp_transport.build_batch_message(data=[]) self._partition_id = partition_id @@ -558,9 +552,7 @@ def _from_batch( ) for m in batch_data ] - batch_data_instance = cls( - partition_key=partition_key, amqp_transport=amqp_transport - ) + batch_data_instance = cls(partition_key=partition_key) for event_data in outgoing_batch_data: batch_data_instance.add(event_data) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py index de5659411a84..eb8fd4f6198f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py @@ -34,7 +34,6 @@ TIMEOUT_SYMBOL = b"com.microsoft:timeout" RECEIVER_RUNTIME_METRIC_SYMBOL = b"com.microsoft:enable-receiver-runtime-metric" -MAX_MESSAGE_LENGTH_BYTES = 1024 * 1024 MAX_USER_AGENT_LENGTH = 512 ALL_PARTITIONS = "all-partitions" CONTAINER_PREFIX = "eventhub.pysdk-" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index 76d975d04cf6..efebf97f5fe0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -21,7 +21,7 @@ from ._client_base import ClientBase from ._producer import EventHubProducer -from ._constants import ALL_PARTITIONS, MAX_MESSAGE_LENGTH_BYTES +from ._constants import ALL_PARTITIONS from ._common import EventDataBatch, EventData from ._buffered_producer import BufferedProducerDispatcher from ._utils import set_event_partition_key @@ -247,8 +247,7 @@ def _buffered_send(self, events, **kwargs): self._max_message_size_on_link, max_wait_time=self._max_wait_time, max_buffer_length=self._max_buffer_length, - executor=self._executor, - amqp_transport=self._amqp_transport + executor=self._executor ) self._buffered_producer_dispatcher.enqueue_events(events, **kwargs) @@ -324,7 +323,7 @@ def _get_max_message_size(self): self._amqp_transport.get_remote_max_message_size( self._producers[ALL_PARTITIONS]._handler # type: ignore ) - or MAX_MESSAGE_LENGTH_BYTES + or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES ) def _start_producer(self, partition_id, send_timeout): @@ -728,8 +727,7 @@ def create_batch(self, **kwargs): event_data_batch = EventDataBatch( max_size_in_bytes=(max_size_in_bytes or self._max_message_size_on_link), partition_id=partition_id, - partition_key=partition_key, - amqp_transport=self._amqp_transport, + partition_key=partition_key ) return event_data_batch diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py index 8fff57dc543d..76144882772d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -10,6 +10,7 @@ class AmqpTransport(ABC): """ # define constants MAX_FRAME_SIZE_BYTES = None + MAX_MESSAGE_LENGTH_BYTES = None IDLE_TIMEOUT_FACTOR = None CONNECTION_CLOSING_STATES = None diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index d90011dbe664..0ca2eca3d154 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -82,7 +82,8 @@ class UamqpTransport(AmqpTransport): Class which defines uamqp-based methods used by the producer and consumer. """ # define constants - MAX_FRAME_SIZE_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES + MAX_FRAME_SIZE_BYTES = constants.MAX_FRAME_SIZE_BYTES + MAX_MESSAGE_LENGTH_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES IDLE_TIMEOUT_FACTOR = 1000 CONNECTION_CLOSING_STATES = ( # pylint:disable=protected-access c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py index 67fee8dd2a58..93a4b4d345dd 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py @@ -34,8 +34,7 @@ def __init__( max_message_size_on_link: int, *, max_wait_time: float = 1, - max_buffer_length: int, - amqp_transport: AmqpTransportAsync + max_buffer_length: int ): self._buffered_queue: queue.Queue = queue.Queue() self._max_buffer_len = max_buffer_length @@ -50,12 +49,11 @@ def __init__( self._cur_batch: Optional[EventDataBatch] = None self._max_message_size_on_link = max_message_size_on_link self._check_max_wait_time_future = None - self._amqp_transport = amqp_transport self.partition_id = partition_id async def start(self): async with self._lock: - self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) self._running = True if self._max_wait_time: self._last_send_time = time.time() @@ -117,11 +115,11 @@ async def put_events(self, events, timeout_time=None): self._buffered_queue.put(self._cur_batch) self._buffered_queue.put(events) # create a new batch for incoming events - self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) except ValueError: # add single event exceeds the cur batch size, create new batch self._buffered_queue.put(self._cur_batch) - self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) self._cur_batch.add(events) self._cur_buffered_len += new_events_len @@ -149,7 +147,7 @@ async def _flush(self, timeout_time=None, raise_error=True): _LOGGER.info("Partition: %r started flushing.", self.partition_id) if self._cur_batch: # if there is batch, enqueue it to the buffer first self._buffered_queue.put(self._cur_batch) - self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) while self._cur_buffered_len: remaining_time = timeout_time - time.time() if timeout_time else None if (remaining_time and remaining_time > 0) or remaining_time is None: @@ -191,7 +189,7 @@ async def _flush(self, timeout_time=None, raise_error=True): break # after finishing flushing, reset cur batch and put it into the buffer self._last_send_time = time.time() - self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) + self._cur_batch = EventDataBatch(self._max_message_size_on_link) _LOGGER.info("Partition %r finished flushing.", self.partition_id) async def check_max_wait_time_worker(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py index 04e5a12ea69f..d3f2135ff170 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py @@ -34,8 +34,7 @@ def __init__( max_message_size_on_link: int, *, max_buffer_length: int = 1500, - max_wait_time: float = 1, - amqp_transport: AmqpTransportAsync, + max_wait_time: float = 1 ): self._buffered_producers: Dict[str, BufferedProducer] = {} self._partition_ids: List[str] = partitions @@ -48,7 +47,6 @@ def __init__( self._partition_resolver = PartitionResolver(self._partition_ids) self._max_wait_time = max_wait_time self._max_buffer_length = max_buffer_length - self._amqp_transport = amqp_transport async def _get_partition_id(self, partition_id, partition_key): if partition_id: @@ -81,7 +79,6 @@ async def enqueue_events( self._max_message_size_on_link, max_wait_time=self._max_wait_time, max_buffer_length=self._max_buffer_length, - amqp_transport=self._amqp_transport, ) await buffered_producer.start() self._buffered_producers[pid] = buffered_producer diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py index 898bbadc6ec2..86788b651a65 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py @@ -452,7 +452,7 @@ async def _open(self) -> None: await asyncio.sleep(0.05, **self._internal_kwargs) self._max_message_size_on_link = ( self._amqp_transport.get_remote_max_message_size(self._handler) - or constants.MAX_FRAME_SIZE_BYTES + or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES ) self.running = True diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py index 290a36641437..c89567ba2cd5 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py @@ -231,7 +231,6 @@ async def _buffered_send(self, events, **kwargs): self._max_message_size_on_link, max_wait_time=self._max_wait_time, max_buffer_length=self._max_buffer_length, - amqp_transport=self._amqp_transport, ) await self._buffered_producer_dispatcher.enqueue_events(events, **kwargs) @@ -306,7 +305,7 @@ async def _get_max_message_size(self) -> None: EventHubProducer, self._producers[ALL_PARTITIONS] )._handler ) - or constants.MAX_MESSAGE_LENGTH_BYTES + or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES ) async def _start_producer( @@ -721,8 +720,7 @@ async def create_batch( event_data_batch = EventDataBatch( max_size_in_bytes=(max_size_in_bytes or self._max_message_size_on_link), partition_id=partition_id, - partition_key=partition_key, - amqp_transport=self._amqp_transport, + partition_key=partition_key ) return event_data_batch diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py index 2aa7e29cf453..58f895220309 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py @@ -10,6 +10,7 @@ class AmqpTransportAsync(ABC): """ # define constants MAX_FRAME_SIZE_BYTES = None + MAX_MESSAGE_LENGTH_BYTES = None IDLE_TIMEOUT_FACTOR = None # define symbols diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py index aae3ce3aafd7..ba58cf2a00c7 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py @@ -34,7 +34,7 @@ def test_send_batch_with_invalid_hostname(invalid_hostname, uamqp_transport): client = EventHubProducerClient.from_connection_string(invalid_hostname, uamqp_transport=uamqp_transport) with client: with pytest.raises(ConnectError): - batch = EventDataBatch(amqp_transport=amqp_transport) + batch = EventDataBatch() batch.add(EventData("test data")) client.send_batch(batch) @@ -47,7 +47,7 @@ def on_error(events, pid, err): on_error.err = None client = EventHubProducerClient.from_connection_string(invalid_hostname, on_error=on_error, uamqp_transport=uamqp_transport) with client: - batch = EventDataBatch(amqp_transport=amqp_transport) + batch = EventDataBatch() batch.add(EventData("test data")) client.send_batch(batch) assert isinstance(on_error.err, ConnectError) @@ -82,7 +82,7 @@ def test_send_batch_with_invalid_key(invalid_key, uamqp_transport): amqp_transport = UamqpTransport if uamqp_transport else None try: with pytest.raises(ConnectError): - batch = EventDataBatch(amqp_transport=amqp_transport) + batch = EventDataBatch() batch.add(EventData("test data")) client.send_batch(batch) finally: diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py index fdc3b640791b..ba120d986abe 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py @@ -478,7 +478,7 @@ def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key, uamqp_ # Use invalid_hostname because this is not a live test. amqp_transport = UamqpTransport if uamqp_transport else None client = EventHubProducerClient.from_connection_string(invalid_hostname, uamqp_transport=uamqp_transport) - batch = EventDataBatch(partition_id=partition_id, partition_key=partition_key, amqp_transport=amqp_transport) + batch = EventDataBatch(partition_id=partition_id, partition_key=partition_key) with client: with pytest.raises(TypeError): client.send_batch(batch, partition_id=partition_id, partition_key=partition_key) From 792bd782d5f157b058cafb996b55c83b5f778497 Mon Sep 17 00:00:00 2001 From: swathipil Date: Mon, 15 Aug 2022 10:53:26 -0700 Subject: [PATCH 14/20] fix bugs --- sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py | 4 ++-- sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py | 2 +- .../azure-eventhub/azure/eventhub/_transport/_base.py | 2 +- .../azure/eventhub/_transport/_uamqp_transport.py | 2 +- .../azure-eventhub/azure/eventhub/aio/_consumer_async.py | 4 ++-- .../azure-eventhub/azure/eventhub/aio/_producer_async.py | 2 +- .../azure/eventhub/aio/_transport/_base_async.py | 2 +- sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index 3aa6423fdd38..8f647a60c6e7 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -98,14 +98,14 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs: Any) link_properties: Dict[uamqp_types.AMQPType, uamqp_types.AMQPType] = {} self._error = None self._timeout = 0 - self._idle_timeout = (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) if idle_timeout else None + self._idle_timeout = (idle_timeout * self._amqp_transport.TIMEOUT_FACTOR) if idle_timeout else None self._partition = self._source.split("/")[-1] self._name = f"EHConsumer-{uuid.uuid4()}-partition{self._partition}" if owner_level is not None: link_properties[EPOCH_SYMBOL] = int(owner_level) link_property_timeout_ms = ( self._client._config.receive_timeout or self._timeout # pylint:disable=protected-access - ) * self._amqp_transport.IDLE_TIMEOUT_FACTOR + ) * self._amqp_transport.TIMEOUT_FACTOR link_properties[TIMEOUT_SYMBOL] = int(link_property_timeout_ms) self._link_properties = self._amqp_transport.create_link_properties(link_properties) self._handler: Optional[uamqp_ReceiveClient] = None diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index 86549929113c..e6101966b9d7 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -104,7 +104,7 @@ def __init__( self._partition = partition self._timeout = send_timeout self._idle_timeout = ( - (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) + (idle_timeout * self._amqp_transport.TIMEOUT_FACTOR) if idle_timeout else None ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py index 76144882772d..641642548fdf 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -11,7 +11,7 @@ class AmqpTransport(ABC): # define constants MAX_FRAME_SIZE_BYTES = None MAX_MESSAGE_LENGTH_BYTES = None - IDLE_TIMEOUT_FACTOR = None + TIMEOUT_FACTOR = None CONNECTION_CLOSING_STATES = None # define symbols diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index 0ca2eca3d154..e64eb0989edf 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -84,7 +84,7 @@ class UamqpTransport(AmqpTransport): # define constants MAX_FRAME_SIZE_BYTES = constants.MAX_FRAME_SIZE_BYTES MAX_MESSAGE_LENGTH_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES - IDLE_TIMEOUT_FACTOR = 1000 + TIMEOUT_FACTOR = 1000 CONNECTION_CLOSING_STATES = ( # pylint:disable=protected-access c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py index 65a1e8ff981e..c918bca559c8 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py @@ -91,7 +91,7 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs) -> N self._retry_policy = self._amqp_transport.create_retry_policy(self._client._config) self._reconnect_backoff = 1 self._timeout = 0 - self._idle_timeout = (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) if idle_timeout else None + self._idle_timeout = (idle_timeout * self._amqp_transport.TIMEOUT_FACTOR) if idle_timeout else None link_properties: Dict[types.AMQPType, types.AMQPType] = {} self._partition = self._source.split("/")[-1] self._name = f"EHReceiver-{uuid.uuid4()}-partition{self._partition}" @@ -100,7 +100,7 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs) -> N link_property_timeout_ms = ( self._client._config.receive_timeout or self._timeout # pylint:disable=protected-access - ) * self._amqp_transport.IDLE_TIMEOUT_FACTOR + ) * self._amqp_transport.TIMEOUT_FACTOR link_properties[TIMEOUT_SYMBOL] = int(link_property_timeout_ms) self._link_properties = self._amqp_transport.create_link_properties(link_properties) self._handler: Optional[ReceiveClientAsync] = None diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py index cef98a1c87d3..8f0fe7d6004f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py @@ -79,7 +79,7 @@ def __init__(self, client: EventHubProducerClient, target: str, **kwargs) -> Non self._auto_reconnect = auto_reconnect self._timeout = send_timeout self._idle_timeout = ( - (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) + (idle_timeout * self._amqp_transport.TIMEOUT_FACTOR) if idle_timeout else None ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py index 58f895220309..d4193301c728 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py @@ -11,7 +11,7 @@ class AmqpTransportAsync(ABC): # define constants MAX_FRAME_SIZE_BYTES = None MAX_MESSAGE_LENGTH_BYTES = None - IDLE_TIMEOUT_FACTOR = None + TIMEOUT_FACTOR = None # define symbols PRODUCT_SYMBOL = None diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py index cf7982ba4992..37a8b9056758 100644 --- a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py @@ -115,7 +115,7 @@ def test_event_data_batch(uamqp_transport): expected_result = 93 else: pass - batch = EventDataBatch(max_size_in_bytes=110, partition_key="par", amqp_transport=amqp_transport) + batch = EventDataBatch(max_size_in_bytes=110, partition_key="par") batch.add(EventData("A")) assert str(batch) == "EventDataBatch(max_size_in_bytes=110, partition_id=None, partition_key='par', event_count=1)" assert repr(batch) == "EventDataBatch(max_size_in_bytes=110, partition_id=None, partition_key='par', event_count=1)" From 2e4437d82f57b33e8fb9ce75433f5e1c88661836 Mon Sep 17 00:00:00 2001 From: swathipil Date: Mon, 15 Aug 2022 16:34:20 -0700 Subject: [PATCH 15/20] fix lint, mypy, errors --- .../azure/eventhub/_client_base.py | 13 ++------- .../azure/eventhub/_connection_manager.py | 12 ++++---- .../azure/eventhub/_producer.py | 6 ++-- .../azure/eventhub/_transport/_base.py | 27 ++++++++++-------- .../eventhub/_transport/_uamqp_transport.py | 10 +++---- .../azure-eventhub/azure/eventhub/_utils.py | 5 +++- .../azure/eventhub/aio/_client_base_async.py | 28 +++++++++---------- .../eventhub/aio/_connection_manager_async.py | 8 ++++-- .../eventhub/aio/_producer_client_async.py | 1 - .../eventhub/aio/_transport/_base_async.py | 27 +++++++++++------- .../aio/_transport/_uamqp_transport_async.py | 2 +- .../azure/eventhub/amqp/_amqp_message.py | 4 +-- 12 files changed, 77 insertions(+), 66 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index a8dd0c14579b..ea8f460cb353 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -9,9 +9,9 @@ import time import functools import collections -from typing import Any, Dict, Tuple, List, Optional, TYPE_CHECKING, cast, Union +from typing import Any, Dict, Tuple, List, Optional, TYPE_CHECKING, TypeAlias, cast, Union from datetime import timedelta -from urllib.parse import urlparse, quote_plus +from urllib.parse import urlparse from azure.core.credentials import ( AccessToken, @@ -42,13 +42,6 @@ from uamqp import Message as uamqp_Message from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth - CredentialTypes = Union[ - AzureSasCredential, - AzureNamedKeyCredential, - "EventHubSharedKeyCredential", - TokenCredential, - ] - _LOGGER = logging.getLogger(__name__) _Address = collections.namedtuple("_Address", "hostname path") @@ -268,7 +261,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument if TYPE_CHECKING: from azure.core.credentials import TokenCredential - CredentialTypes = Union[ + CredentialTypes: TypeAlias = Union[ AzureSasCredential, AzureNamedKeyCredential, EventHubSharedKeyCredential, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py index 2ccd0fe80e07..f8e109a224cc 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py @@ -21,8 +21,9 @@ Protocol = object # type: ignore class ConnectionManager(Protocol): - def get_connection(self, host, auth): - # type: (str, 'JWTTokenAuth') -> Connection + def get_connection( + self, *, host: Optional[str] = None, auth: Optional[JWTTokenAuth] = None, endpoint: Optional[str] = None + ) -> Connection: pass def close_connection(self): @@ -97,8 +98,9 @@ class _SeparateConnectionManager(object): def __init__(self, **kwargs): pass - def get_connection(self, host, auth): # pylint:disable=unused-argument, no-self-use - # type: (str, JWTTokenAuth) -> None + def get_connection( # pylint:disable=unused-argument, no-self-use + self, *, host: Optional[str] = None, auth: Optional[JWTTokenAuth] = None, endpoint: Optional[str] = None + ) -> None: return None def close_connection(self): @@ -112,7 +114,7 @@ def reset_connection_if_broken(self): def get_connection_manager(**kwargs): # type: (...) -> 'ConnectionManager' - connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection) + connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection) # type: ignore if connection_mode == _ConnectionMode.ShareConnection: return _SharedConnectionManager(**kwargs) return _SeparateConnectionManager(**kwargs) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index e6101966b9d7..5b42d964400e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -182,7 +182,7 @@ def _send_event_data_with_retry(self, timeout: Optional[float] = None) -> None: def _wrap_eventdata( self, - event_data: Union[EventData, EventDataBatch, Iterable[EventData]], + event_data: Union[EventData, EventDataBatch, Iterable[EventData], AmqpAnnotatedMessage], span: Optional["AbstractSpan"], partition_key: Optional[AnyStr], ) -> Union[EventData, EventDataBatch]: @@ -228,7 +228,7 @@ def _wrap_eventdata( def send( self, - event_data: Union[EventData, EventDataBatch, Iterable[EventData]], + event_data: Union[EventData, EventDataBatch, Iterable[EventData], AmqpAnnotatedMessage], partition_key: Optional[AnyStr] = None, timeout: Optional[float] = None, ) -> None: @@ -237,7 +237,7 @@ def send( received or operation times out. :param event_data: The event to be sent. It can be an EventData object, or iterable of EventData objects - :type event_data: ~azure.eventhub.common.EventData, Iterator, Generator, list + :type event_data: ~azure.eventhub.common.EventData, Iterator, Generator, list or AmqpAnnotatedMessage :param partition_key: With the given partition_key, event data will land to a particular partition of the Event Hub decided by the service. partition_key could be omitted if event_data is of type ~azure.eventhub.EventDataBatch. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py index 641642548fdf..d67cceedcd40 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -2,25 +2,30 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations +from typing import Tuple, Union, TYPE_CHECKING from abc import ABC, abstractmethod -class AmqpTransport(ABC): +if TYPE_CHECKING: + from uamqp import types as uamqp_types + +class AmqpTransport(ABC): # pylint: disable=too-many-public-methods """ Abstract class that defines a set of common methods needed by producer and consumer. """ # define constants - MAX_FRAME_SIZE_BYTES = None - MAX_MESSAGE_LENGTH_BYTES = None - TIMEOUT_FACTOR = None - CONNECTION_CLOSING_STATES = None + MAX_FRAME_SIZE_BYTES: int + MAX_MESSAGE_LENGTH_BYTES: int + TIMEOUT_FACTOR: int + CONNECTION_CLOSING_STATES: Tuple # define symbols - PRODUCT_SYMBOL = None - VERSION_SYMBOL = None - FRAMEWORK_SYMBOL = None - PLATFORM_SYMBOL = None - USER_AGENT_SYMBOL = None - PROP_PARTITION_KEY_AMQP_SYMBOL = None + PRODUCT_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + VERSION_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + FRAMEWORK_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + PLATFORM_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + USER_AGENT_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + PROP_PARTITION_KEY_AMQP_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] @staticmethod @abstractmethod diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index e64eb0989edf..018d3611aa72 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -5,7 +5,7 @@ import time import logging -from typing import Optional, Union, Any +from typing import Optional, Union, Any, Tuple try: from uamqp import ( @@ -77,7 +77,7 @@ def _error_handler(error): return errors.ErrorAction(retry=True) - class UamqpTransport(AmqpTransport): + class UamqpTransport(AmqpTransport): # pylint: disable=too-many-public-methods """ Class which defines uamqp-based methods used by the producer and consumer. """ @@ -85,7 +85,7 @@ class UamqpTransport(AmqpTransport): MAX_FRAME_SIZE_BYTES = constants.MAX_FRAME_SIZE_BYTES MAX_MESSAGE_LENGTH_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES TIMEOUT_FACTOR = 1000 - CONNECTION_CLOSING_STATES = ( # pylint:disable=protected-access + CONNECTION_CLOSING_STATES: Tuple = ( # pylint:disable=protected-access c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member c_uamqp.ConnectionState.DISCARDING, # pylint:disable=c-extension-no-member @@ -260,7 +260,7 @@ def get_connection_state(connection): Gets connection state. :param connection: uamqp or pyamqp Connection. """ - return connection._state + return connection._state # pylint:disable=protected-access @staticmethod def create_send_client(*, config, **kwargs): # pylint:disable=unused-argument @@ -284,7 +284,7 @@ def create_send_client(*, config, **kwargs): # pylint:disable=unused-argument return SendClient( target, - debug=network_trace, # pylint:disable=protected-access + debug=network_trace, error_policy=retry_policy, **kwargs ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index 7815e71f121b..fc9b1bc8c9c9 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -17,6 +17,7 @@ import time from typing import ( TYPE_CHECKING, + cast, Type, Optional, Dict, @@ -100,7 +101,7 @@ def create_properties( :rtype: dict """ - properties = {} + properties: Dict[Any, str] = {} properties[amqp_transport.PRODUCT_SYMBOL] = USER_AGENT_PREFIX properties[amqp_transport.VERSION_SYMBOL] = VERSION framework = f"Python/{sys.version_info[0]}.{sys.version_info[1]}.{sys.version_info[2]}" @@ -295,6 +296,7 @@ def transform_outbound_single_message(message, message_type, to_outgoing_amqp_me try: # pylint: disable=protected-access # If EventData, set EventData._message to uamqp/pyamqp.Message right before sending. + message = cast("EventData", message) message._message = to_outgoing_amqp_message(message.raw_amqp_message) return message # type: ignore except AttributeError: @@ -302,6 +304,7 @@ def transform_outbound_single_message(message, message_type, to_outgoing_amqp_me # If AmqpAnnotatedMessage, create EventData object with _from_message. # event_data._message will be set to outgoing uamqp/pyamqp.Message. # event_data.raw_amqp_message will be set to AmqpAnnotatedMessage. + message = cast(AmqpAnnotatedMessage, message) amqp_message = to_outgoing_amqp_message(message) return message_type._from_message( message=amqp_message, raw_amqp_message=message # type: ignore diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py index 86788b651a65..e9becf77a884 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py @@ -10,13 +10,6 @@ import functools from typing import TYPE_CHECKING, Any, Dict, List, Callable, Optional, Union, cast -import six -from uamqp import ( - authentication, - constants, - Message, - AMQPClientAsync, -) from azure.core.credentials import ( AccessToken, AzureSasCredential, @@ -44,6 +37,11 @@ from ._transport._uamqp_transport_async import UamqpTransportAsync if TYPE_CHECKING: + from uamqp import ( + authentication, + Message, + AMQPClientAsync, + ) from azure.core.credentials_async import AsyncTokenCredential CredentialTypes = Union[ @@ -337,7 +335,7 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> description = response.application_properties.get( MGMT_STATUS_DESC ) # type: Optional[Union[str, bytes]] - if description and isinstance(description, six.binary_type): + if description and isinstance(description, bytes): description = description.decode("utf-8") if status_code < 400: return response @@ -444,15 +442,16 @@ async def _open(self) -> None: await self._handler.close_async() auth = await self._client._create_auth_async() self._create_handler(auth) - conn = await self._conn_manager_async.get_connection( - host=self._address.hostname, auth=auth + conn = await self._client._conn_manager_async.get_connection( + host=self._client._address.hostname, auth=auth ) await self._handler.open_async(connection=conn) while not await self._handler.client_ready_async(): await asyncio.sleep(0.05, **self._internal_kwargs) + # pylint: disable=protected-access self._max_message_size_on_link = ( - self._amqp_transport.get_remote_max_message_size(self._handler) - or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES + self._client._amqp_transport.get_remote_max_message_size(self._handler) + or self._client._amqp_transport.MAX_MESSAGE_LENGTH_BYTES ) self.running = True @@ -467,8 +466,9 @@ async def _close_connection_async(self) -> None: await self._client._conn_manager_async.reset_connection_if_broken() # pylint:disable=protected-access async def _handle_exception(self, exception: Exception) -> Exception: - exception = self._amqp_transport.check_timeout_exception(self, exception) - return await self._amqp_transport._handle_exception_async( # pylint: disable=protected-access + # pylint: disable=protected-access + exception = self._client._amqp_transport.check_timeout_exception(self, exception) + return await self._client._amqp_transport._handle_exception_async( exception, self ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py index d9e524391c6a..ec3c0fffaebf 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py @@ -22,7 +22,7 @@ class ConnectionManager(Protocol): async def get_connection( - self, host: str, auth: JWTTokenAsync + self, *, host: Optional[str] = None, auth: Optional[JWTTokenAsync] = None, endpoint: Optional[str] = None ) -> ConnectionAsync: pass @@ -93,7 +93,9 @@ class _SeparateConnectionManager(object): def __init__(self, **kwargs) -> None: pass - async def get_connection(self, host: str, auth: JWTTokenAsync) -> None: + async def get_connection( + self, *, host: Optional[str] = None, auth: Optional[JWTTokenAsync] = None, endpoint: Optional[str] = None + ) -> None: pass # return None async def close_connection(self) -> None: @@ -104,7 +106,7 @@ async def reset_connection_if_broken(self) -> None: def get_connection_manager(**kwargs) -> "ConnectionManager": - connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection) + connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection) # type: ignore if connection_mode == _ConnectionMode.ShareConnection: return _SharedConnectionManager(**kwargs) return _SeparateConnectionManager(**kwargs) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py index c89567ba2cd5..af59da1efc5a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py @@ -8,7 +8,6 @@ from typing import Any, Union, List, Optional, Dict, Callable, cast from typing_extensions import TYPE_CHECKING, Literal, Awaitable, overload -from uamqp import constants from ..exceptions import ConnectError, EventHubError from ..amqp import AmqpAnnotatedMessage diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py index d4193301c728..ce9342c607ed 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py @@ -2,24 +2,31 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations +from typing import Tuple, Union, TYPE_CHECKING from abc import ABC, abstractmethod -class AmqpTransportAsync(ABC): +if TYPE_CHECKING: + from uamqp import types as uamqp_types + +class AmqpTransportAsync(ABC): # pylint: disable=too-many-public-methods """ Abstract class that defines a set of common methods needed by producer and consumer. """ # define constants - MAX_FRAME_SIZE_BYTES = None - MAX_MESSAGE_LENGTH_BYTES = None - TIMEOUT_FACTOR = None + MAX_FRAME_SIZE_BYTES: int + MAX_MESSAGE_LENGTH_BYTES: int + TIMEOUT_FACTOR: int + CONNECTION_CLOSING_STATES: Tuple # define symbols - PRODUCT_SYMBOL = None - VERSION_SYMBOL = None - FRAMEWORK_SYMBOL = None - PLATFORM_SYMBOL = None - USER_AGENT_SYMBOL = None - PROP_PARTITION_KEY_AMQP_SYMBOL = None + PRODUCT_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + VERSION_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + FRAMEWORK_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + PLATFORM_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + USER_AGENT_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + PROP_PARTITION_KEY_AMQP_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + @staticmethod @abstractmethod diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py index 1f505af2150d..72b91aee7766 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py @@ -199,7 +199,7 @@ async def receive_messages(consumer, batch, max_batch_size, max_wait_time): try: await consumer._open() await cast( - ReceiveClientAsync, consumer._consumer + ReceiveClientAsync, consumer._handler ).do_work_async() # uamqp sleeps 0.05 if none received break except asyncio.CancelledError: # pylint: disable=try-except-raise diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index d2779e0a6cfc..d6cc5937a11d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -181,9 +181,9 @@ def body(self) -> Any: :rtype: Any """ if self._body_type == AmqpMessageBodyType.DATA: # pylint:disable=no-else-return - return (i for i in self._data_body) # type: ignore + return (i for i in cast(List, self._data_body)) # type: ignore elif self._body_type == AmqpMessageBodyType.SEQUENCE: - return (i for i in self._sequence_body) + return (i for i in cast(List, self._sequence_body)) elif self._body_type == AmqpMessageBodyType.VALUE: return self._value_body return None From ef684504de804da6a536a41db60589d9815914e9 Mon Sep 17 00:00:00 2001 From: swathipil Date: Tue, 16 Aug 2022 12:37:27 -0700 Subject: [PATCH 16/20] update tests to take uamqp TransportType as well --- sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py | 2 +- .../azure-eventhub/tests/livetest/asynctests/test_send_async.py | 2 +- .../azure-eventhub/tests/livetest/synctests/test_send.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py index 06d3e3a21c7a..00c03ca4197b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py @@ -38,7 +38,7 @@ def __init__(self, **kwargs): self.hostname = kwargs.pop("hostname") uamqp_transport = kwargs.pop("uamqp_transport") - if self.http_proxy or self.transport_type == TransportType.AmqpOverWebsocket: + if self.http_proxy or self.transport_type.value == TransportType.AmqpOverWebsocket.value: self.transport_type = TransportType.AmqpOverWebsocket self.connection_port = DEFAULT_AMQP_WSS_PORT if not uamqp_transport: diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py index 255818a2511c..141d2b4861d6 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py @@ -343,7 +343,7 @@ async def test_send_multiple_partition_with_app_prop_async(connstr_receivers): async def test_send_over_websocket_async(connstr_receivers): connection_str, receivers = connstr_receivers client = EventHubProducerClient.from_connection_string(connection_str, - transport_type=TransportType.AmqpOverWebsocket) + transport_type=uamqp.constants.TransportType.AmqpOverWebsocket) async with client: batch = await client.create_batch(partition_id="0") diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py index ba120d986abe..d75b4d013470 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py @@ -389,7 +389,7 @@ def test_send_over_websocket_sync(connstr_receivers, uamqp_transport, timeout_fa timeout = 10 * timeout_factor connection_str, receivers = connstr_receivers client = EventHubProducerClient.from_connection_string( - connection_str, transport_type=TransportType.AmqpOverWebsocket, uamqp_transport=uamqp_transport + connection_str, transport_type=uamqp.constants.TransportType.AmqpOverWebsocket, uamqp_transport=uamqp_transport ) with client: From b2a5914265843990bb5405d7a90502e770eed18a Mon Sep 17 00:00:00 2001 From: swathipil Date: Wed, 17 Aug 2022 09:16:45 -0700 Subject: [PATCH 17/20] update changelog --- sdk/eventhub/azure-eventhub/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index 2299880510f9..2906138f1d64 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -7,6 +7,7 @@ This version and all future versions will require Python 3.7+, Python 3.6 is no ### Bugs Fixed - Fixed a bug in `BufferedProducer` that would block when flushing the queue causing the client to freeze up (issue #23510). +- Fixed a bug in the async `EventHubProducerClient` and `EventHubConsumerClient` that set the default value of the `transport_type` parameter in the constructors to `None` rather than `TransportType.Amqp`. ### Other Changes From 40933b3a3286a58a1f690a6f8117ab41b4119613 Mon Sep 17 00:00:00 2001 From: swathipil Date: Thu, 18 Aug 2022 12:50:15 -0700 Subject: [PATCH 18/20] update uamqp min dep + release date --- sdk/eventhub/azure-eventhub/CHANGELOG.md | 3 ++- sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py | 6 +++++- sdk/eventhub/azure-eventhub/dev_requirements.txt | 2 +- sdk/eventhub/azure-eventhub/setup.py | 2 +- shared_requirements.txt | 4 ++-- 5 files changed, 11 insertions(+), 6 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index 2906138f1d64..07fe891132ae 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -1,6 +1,6 @@ # Release History -## 5.10.1 (Unreleased) +## 5.10.1 (2022-08-18) This version and all future versions will require Python 3.7+, Python 3.6 is no longer supported. @@ -12,6 +12,7 @@ This version and all future versions will require Python 3.7+, Python 3.6 is no ### Other Changes - Internal refactoring to support upcoming Pure Python AMQP-based release. +- Updated uAMQP dependency to 1.6.0. ## 5.10.0 (2022-06-08) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index ea8f460cb353..61e59688cbc4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -9,7 +9,11 @@ import time import functools import collections -from typing import Any, Dict, Tuple, List, Optional, TYPE_CHECKING, TypeAlias, cast, Union +from typing import Any, Dict, Tuple, List, Optional, TYPE_CHECKING, cast, Union +try: + from typing import TypeAlias +except ImportError: + from typing_extensions import TypeAlias from datetime import timedelta from urllib.parse import urlparse diff --git a/sdk/eventhub/azure-eventhub/dev_requirements.txt b/sdk/eventhub/azure-eventhub/dev_requirements.txt index 308c4c22ead4..eef3c3014389 100644 --- a/sdk/eventhub/azure-eventhub/dev_requirements.txt +++ b/sdk/eventhub/azure-eventhub/dev_requirements.txt @@ -4,4 +4,4 @@ azure-mgmt-eventhub==10.0.0 azure-mgmt-resource==20.0.0 aiohttp>=3.0 --e ../../../tools/azure-devtools \ No newline at end of file +-e ../../../tools/azure-devtools diff --git a/sdk/eventhub/azure-eventhub/setup.py b/sdk/eventhub/azure-eventhub/setup.py index f7c5843f4a68..b7d2c5294da4 100644 --- a/sdk/eventhub/azure-eventhub/setup.py +++ b/sdk/eventhub/azure-eventhub/setup.py @@ -69,7 +69,7 @@ packages=find_packages(exclude=exclude_packages), install_requires=[ "azure-core<2.0.0,>=1.14.0", - "uamqp>=1.5.1,<2.0.0", + "uamqp>=1.6.0,<2.0.0", "typing-extensions>=4.0.1", ] ) diff --git a/shared_requirements.txt b/shared_requirements.txt index dc06fe51f8de..b967e2e90805 100644 --- a/shared_requirements.txt +++ b/shared_requirements.txt @@ -118,7 +118,7 @@ msrest>=0.6.21 msrestazure<2.0.0,>=0.4.32 azure-mgmt-core<2.0.0,>=1.3.0 requests>=2.18.4 -uamqp~=1.5.0 +uamqp~=1.6.0 enum34>=1.0.4 certifi>=2017.4.17 aiohttp>=3.0 @@ -175,7 +175,7 @@ opentelemetry-sdk<2.0.0,>=1.5.0,!=1.10a0 #override azure-eventhub-checkpointstoreblob-aio azure-core<2.0.0,>=1.20.1 #override azure-eventhub-checkpointstoreblob-aio aiohttp<4.0,>=3.0 #override azure-eventhub-checkpointstoretable azure-core<2.0.0,>=1.14.0 -#override azure-eventhub uamqp>=1.5.1,<2.0.0 +#override azure-eventhub uamqp>=1.6.0,<2.0.0 #override azure-appconfiguration msrest>=0.6.10 #override azure-servicebus uamqp>=1.5.1,<2.0.0 #override azure-servicebus msrest>=0.6.17,<2.0.0 From 73fdaa820306f666c7f6c3a84155f90b2074a97e Mon Sep 17 00:00:00 2001 From: swathipil Date: Thu, 18 Aug 2022 15:33:16 -0700 Subject: [PATCH 19/20] update message prop back to ivar --- .../azure-eventhub/azure/eventhub/_common.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 9e0cb03730b8..019f7c1d5b44 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -126,6 +126,7 @@ def __init__( ) # amqp message to be reset right before sending self._message = UamqpTransport.to_outgoing_amqp_message(self._raw_amqp_message) + self.message = self._message self._raw_amqp_message.header = AmqpMessageHeader() self._raw_amqp_message.properties = AmqpMessageProperties() self.message_id = None @@ -242,14 +243,6 @@ def _decode_non_data_body_as_str(self, encoding: str = "UTF-8") -> str: seq_list = [d for seq_section in body for d in seq_section] return str(decode_with_recurse(seq_list, encoding)) - @property - def message(self) -> uamqp_Message: - return self._message - - @message.setter - def message(self, value: uamqp_Message) -> None: - self._message = value - @property def raw_amqp_message(self) -> AmqpAnnotatedMessage: """Advanced usage only. The internal AMQP message payload that is sent or received.""" @@ -523,6 +516,7 @@ def __init__( self._message = self._amqp_transport.set_message_partition_key( self._message, self._partition_key ) + self.message: uamqp_BatchMessage = self._message self._size = self._amqp_transport.get_batch_message_encoded_size(self._message) self._count = 0 self._internal_events: List[ @@ -577,14 +571,6 @@ def size_in_bytes(self) -> int: """ return self._size - @property - def message(self) -> uamqp_BatchMessage: - return self._message - - @message.setter - def message(self, value: uamqp_BatchMessage) -> None: - self._message = value - def add(self, event_data: Union[EventData, AmqpAnnotatedMessage]) -> None: """Try to add an EventData to the batch. From f05f36226210ff1e53911499e19fa320b077864a Mon Sep 17 00:00:00 2001 From: swathipil Date: Thu, 18 Aug 2022 17:23:28 -0700 Subject: [PATCH 20/20] set message ivar when creating ED._from_message --- sdk/eventhub/azure-eventhub/azure/eventhub/_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 019f7c1d5b44..969d9c5bfb5f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -225,6 +225,7 @@ def _from_message( event_data = cls(body="") # pylint: disable=protected-access event_data._message = message + event_data.message = message event_data._raw_amqp_message = ( raw_amqp_message if raw_amqp_message