diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 631a0b4e1aff..633a90f8d0a9 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -146,35 +146,42 @@ def __repr__(self) -> str: # pylint: disable=bare-except try: body_str = self.body_as_str() - except: + except Exception as e: # pylint: disable=broad-except + _LOGGER.debug("Message body read error: %r", e) body_str = "" event_repr = f"body='{body_str}'" try: event_repr += f", properties={self.properties}" - except: + except Exception as e: # pylint: disable=broad-except + _LOGGER.debug("Message properties read error: %r", e) event_repr += ", properties=" try: event_repr += f", offset={self.offset}" - except: + except Exception as e: # pylint: disable=broad-except + _LOGGER.debug("Message offset read error: %r", e) event_repr += ", offset=" try: event_repr += f", sequence_number={self.sequence_number}" - except: + except Exception as e: # pylint: disable=broad-except + _LOGGER.debug("Message sequence number read error: %r", e) event_repr += ", sequence_number=" try: event_repr += f", partition_key={self.partition_key!r}" - except: + except Exception as e: # pylint: disable=broad-except + _LOGGER.debug("Message partition key read error: %r", e) event_repr += ", partition_key=" try: event_repr += f", enqueued_time={self.enqueued_time!r}" - except: + except Exception as e: # pylint: disable=broad-except + _LOGGER.debug("Message enqueued time read error: %r", e) event_repr += ", enqueued_time=" return f"EventData({event_repr})" def __str__(self) -> str: try: body_str = self.body_as_str() - except: # pylint: disable=bare-except + except Exception as e: # pylint: disable=broad-except + _LOGGER.debug("Message body read error: %r", e) body_str = "" event_str = f"{{ body: '{body_str}'" try: @@ -187,8 +194,8 @@ def __str__(self) -> str: event_str += f", 
partition_key={self.partition_key!r}" if self.enqueued_time: event_str += f", enqueued_time={self.enqueued_time!r}" - except: # pylint: disable=bare-except - pass + except Exception as e: # pylint: disable=broad-except + _LOGGER.debug("Message metadata read error: %r", e) event_str += " }" return event_str @@ -416,9 +423,11 @@ def body_as_str(self, encoding: str = "UTF-8") -> str: if self.body_type != AmqpMessageBodyType.DATA: return self._decode_non_data_body_as_str(encoding=encoding) return "".join(b.decode(encoding) for b in cast(Iterable[bytes], data)) - except TypeError: + except UnicodeDecodeError as e: + raise TypeError(f"Message data is not compatible with string type: {e}") + except TypeError as e: return str(data) - except: # pylint: disable=bare-except + except Exception: # pylint: disable=broad-except pass try: return cast(bytes, data).decode(encoding) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py index 44f4e677d3fe..4412e9f58733 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py @@ -104,6 +104,9 @@ def __init__(self, endpoint, **kwargs): # pylint:disable=too-many-statements custom_parsed_url = urlparse(custom_endpoint_address) custom_port = custom_parsed_url.port or WEBSOCKET_PORT custom_endpoint = f"{custom_parsed_url.hostname}:{custom_port}{custom_parsed_url.path}" + self._container_id = kwargs.pop("container_id", None) or str(uuid.uuid4()) # type: str + self._network_trace = kwargs.get("network_trace", False) + self._network_trace_params = {"amqpConnection": self._container_id, "amqpSession": None, "amqpLink": None} transport = kwargs.get("transport") self._transport_type = kwargs.pop("transport_type", TransportType.Amqp) @@ -115,12 +118,18 @@ def __init__(self, endpoint, **kwargs): # pylint:disable=too-many-statements sasl_transport = 
SASLWithWebSocket endpoint = parsed_url.hostname + parsed_url.path self._transport = sasl_transport( - host=endpoint, credential=kwargs["sasl_credential"], custom_endpoint=custom_endpoint, **kwargs + host=endpoint, + credential=kwargs["sasl_credential"], + custom_endpoint=custom_endpoint, + network_trace_params=self._network_trace_params, + **kwargs ) else: - self._transport = Transport(parsed_url.netloc, transport_type=self._transport_type, **kwargs) - - self._container_id = kwargs.pop("container_id", None) or str(uuid.uuid4()) # type: str + self._transport = Transport( + parsed_url.netloc, + transport_type=self._transport_type, + network_trace_params=self._network_trace_params, + **kwargs) self._max_frame_size = kwargs.pop("max_frame_size", MAX_FRAME_SIZE_BYTES) # type: int self._remote_max_frame_size = None # type: Optional[int] self._channel_max = kwargs.pop("channel_max", MAX_CHANNELS) # type: int @@ -138,8 +147,6 @@ def __init__(self, endpoint, **kwargs): # pylint:disable=too-many-statements self._last_frame_received_time = None # type: Optional[float] self._last_frame_sent_time = None # type: Optional[float] self._idle_wait_time = kwargs.get("idle_wait_time", 0.1) # type: float - self._network_trace = kwargs.get("network_trace", False) - self._network_trace_params = {"connection": self._container_id, "session": None, "link": None} self._error = None self._outgoing_endpoints = {} # type: Dict[int, Session] self._incoming_endpoints = {} # type: Dict[int, Session] @@ -158,7 +165,12 @@ def _set_state(self, new_state): return previous_state = self.state self.state = new_state - _LOGGER.info("Connection '%s' state changed: %r -> %r", self._container_id, previous_state, new_state) + _LOGGER.info( + "Connection state changed: %r -> %r", + previous_state, + new_state, + extra=self._network_trace_params + ) for session in self._outgoing_endpoints.values(): session._on_connection_state_change() # pylint:disable=protected-access @@ -181,7 +193,7 @@ def _connect(self): 
self._set_state(ConnectionState.HDR_SENT) if not self._allow_pipelined_open: # TODO: List/tuple expected as variable args - self._process_incoming_frame(*self._read_frame(wait=True)) # type: ignore + self._read_frame(wait=True) if self.state != ConnectionState.HDR_EXCH: self._disconnect() raise ValueError("Did not receive reciprocal protocol header. Disconnecting.") @@ -211,9 +223,9 @@ def _can_read(self): """Whether the connection is in a state where it is legal to read for incoming frames.""" return self.state not in (ConnectionState.CLOSE_RCVD, ConnectionState.END) - def _read_frame( # type: ignore # TODO: missing return + def _read_frame( self, wait: Union[bool, float] = True, **kwargs: Any - ) -> Tuple[int, Optional[Tuple[int, NamedTuple]]]: + ) -> bool: """Read an incoming frame from the transport. :param Union[bool, float] wait: Whether to block on the socket while waiting for an incoming frame. @@ -224,16 +236,15 @@ def _read_frame( # type: ignore # TODO: missing return :returns: A tuple with the incoming channel number, and the frame in the form or a tuple of performative descriptor and field values. 
""" - if self._can_read(): - if wait is False: - return self._transport.receive_frame(**kwargs) - if wait is True: - with self._transport.block(): - return self._transport.receive_frame(**kwargs) - else: - with self._transport.block_with_timeout(timeout=wait): - return self._transport.receive_frame(**kwargs) - _LOGGER.warning("Cannot read frame in current state: %r", self.state) + if wait is False: + new_frame = self._transport.receive_frame(**kwargs) + elif wait is True: + with self._transport.block(): + new_frame = self._transport.receive_frame(**kwargs) + else: + with self._transport.block_with_timeout(timeout=wait): + new_frame = self._transport.receive_frame(**kwargs) + return self._process_incoming_frame(*new_frame) def _can_write(self): # type: () -> bool @@ -271,7 +282,7 @@ def _send_frame(self, channel, frame, timeout=None, **kwargs): except Exception: # pylint:disable=try-except-raise raise else: - _LOGGER.warning("Cannot write frame in current state: %r", self.state) + _LOGGER.info("Cannot write frame in current state: %r", self.state, extra=self._network_trace_params) def _get_next_outgoing_channel(self): # type: () -> int @@ -290,7 +301,7 @@ def _outgoing_empty(self): # type: () -> None """Send an empty frame to prevent the connection from reaching an idle timeout.""" if self._network_trace: - _LOGGER.info("-> empty()", extra=self._network_trace_params) + _LOGGER.debug("-> EmptyFrame()", extra=self._network_trace_params) try: raise self._error except TypeError: @@ -313,14 +324,14 @@ def _outgoing_header(self): """Send the AMQP protocol header to initiate the connection.""" self._last_frame_sent_time = time.time() if self._network_trace: - _LOGGER.info("-> header(%r)", HEADER_FRAME, extra=self._network_trace_params) + _LOGGER.debug("-> Header(%r)", HEADER_FRAME, extra=self._network_trace_params) self._transport.write(HEADER_FRAME) def _incoming_header(self, _, frame): # type: (int, bytes) -> None """Process an incoming AMQP protocol header and update 
the connection state.""" if self._network_trace: - _LOGGER.info("<- header(%r)", frame, extra=self._network_trace_params) + _LOGGER.debug("<- Header(%r)", frame, extra=self._network_trace_params) if self.state == ConnectionState.START: self._set_state(ConnectionState.HDR_RCVD) elif self.state == ConnectionState.HDR_SENT: @@ -344,7 +355,7 @@ def _outgoing_open(self): properties=self._properties, ) if self._network_trace: - _LOGGER.info("-> %r", open_frame, extra=self._network_trace_params) + _LOGGER.debug("-> %r", open_frame, extra=self._network_trace_params) self._send_frame(0, open_frame) def _incoming_open(self, channel, frame): @@ -371,9 +382,9 @@ def _incoming_open(self, channel, frame): """ # TODO: Add type hints for full frame tuple contents. if self._network_trace: - _LOGGER.info("<- %r", OpenFrame(*frame), extra=self._network_trace_params) + _LOGGER.debug("<- %r", OpenFrame(*frame), extra=self._network_trace_params) if channel != 0: - _LOGGER.error("OPEN frame received on a channel that is not 0.") + _LOGGER.error("OPEN frame received on a channel that is not 0.", extra=self._network_trace_params) self.close( error=AMQPError( condition=ErrorCondition.NotAllowed, description="OPEN frame received on a channel that is not 0." @@ -381,7 +392,7 @@ def _incoming_open(self, channel, frame): ) self._set_state(ConnectionState.END) if self.state == ConnectionState.OPENED: - _LOGGER.error("OPEN frame received in the OPENED state.") + _LOGGER.error("OPEN frame received in the OPENED state.", extra=self._network_trace_params) self.close() if frame[4]: self._remote_idle_timeout = frame[4] / 1000 # Convert to seconds @@ -394,19 +405,17 @@ def _incoming_open(self, channel, frame): # If any of the values in the received open frame are invalid then the connection shall be closed. # The error amqp:invalid-field shall be set in the error.condition field of the CLOSE frame. 
self.close( - error=cast( - AMQPError, - AMQPConnectionError( - condition=ErrorCondition.InvalidField, - description="Failed parsing OPEN frame: Max frame size is less than supported minimum.", - ), + error=AMQPError( + condition=ErrorCondition.InvalidField, + description="Failed parsing OPEN frame: Max frame size is less than supported minimum.", ) ) _LOGGER.error( - "Failed parsing OPEN frame: Max frame size is less than supported minimum." + "Failed parsing OPEN frame: Max frame size is less than supported minimum.", + extra=self._network_trace_params ) - else: - self._remote_max_frame_size = frame[2] + return + self._remote_max_frame_size = frame[2] if self.state == ConnectionState.OPEN_SENT: self._set_state(ConnectionState.OPENED) elif self.state == ConnectionState.HDR_EXCH: @@ -420,14 +429,14 @@ def _incoming_open(self, channel, frame): description=f"connection is an illegal state: {self.state}", ) ) - _LOGGER.error("connection is an illegal state: %r", self.state) + _LOGGER.error("Connection is in an illegal state: %r", self.state, extra=self._network_trace_params) def _outgoing_close(self, error=None): # type: (Optional[AMQPError]) -> None """Send a Close frame to shutdown connection with optional error information.""" close_frame = CloseFrame(error=error) if self._network_trace: - _LOGGER.info("-> %r", close_frame, extra=self._network_trace_params) + _LOGGER.debug("-> %r", close_frame, extra=self._network_trace_params) self._send_frame(0, close_frame) def _incoming_close(self, channel, frame): @@ -440,7 +449,7 @@ def _incoming_close(self, channel, frame): """ if self._network_trace: - _LOGGER.info("<- %r", CloseFrame(*frame), extra=self._network_trace_params) + _LOGGER.debug("<- %r", CloseFrame(*frame), extra=self._network_trace_params) disconnect_states = [ ConnectionState.HDR_RCVD, ConnectionState.HDR_EXCH, @@ -450,25 +459,27 @@ def _incoming_close(self, channel, frame): ] if self.state in disconnect_states: self._disconnect() - 
self._set_state(ConnectionState.END) return close_error = None if channel > self._channel_max: - _LOGGER.error("Invalid channel") + _LOGGER.error( + "CLOSE frame received on a channel greater than supported max.", + extra=self._network_trace_params + ) close_error = AMQPError(condition=ErrorCondition.InvalidField, description="Invalid channel", info=None) self._set_state(ConnectionState.CLOSE_RCVD) self._outgoing_close(error=close_error) self._disconnect() - self._set_state(ConnectionState.END) if frame[0]: self._error = AMQPConnectionError( condition=frame[0][0], description=frame[0][1], info=frame[0][2] ) _LOGGER.error( - "Connection error: %r", frame[0] + "Connection closed with error: %r", frame[0], + extra=self._network_trace_params ) @@ -527,6 +538,10 @@ def _incoming_end(self, channel, frame): condition=ErrorCondition.ConnectionCloseForced, description="Invalid channel number received" )) + _LOGGER.error( + "END frame received on invalid channel. Closing connection.", + extra=self._network_trace_params + ) return def _process_incoming_frame(self, channel, frame): # pylint:disable=too-many-return-statements @@ -591,7 +606,7 @@ def _process_incoming_frame(self, channel, frame): # pylint:disable=too-many-re return True if performative == 1: return False - _LOGGER.error("Unrecognized incoming frame: %s", frame) + _LOGGER.error("Unrecognized incoming frame: %r", frame, extra=self._network_trace_params) return True except KeyError: return True # TODO: channel error @@ -619,6 +634,10 @@ def _process_outgoing_frame(self, channel, frame): cast(float, self._idle_timeout), cast(float, self._last_frame_received_time), ) or self._get_remote_timeout(now): + _LOGGER.info( + "No frame received for the idle timeout. 
Closing connection.", + extra=self._network_trace_params + ) self.close( error=AMQPError( condition=ErrorCondition.ConnectionCloseForced, @@ -699,6 +718,10 @@ def listen(self, wait=False, batch=1, **kwargs): ) or self._get_remote_timeout( now ): + _LOGGER.info( + "No frame received for the idle timeout. Closing connection.", + extra=self._network_trace_params + ) self.close( error=AMQPError( condition=ErrorCondition.ConnectionCloseForced, @@ -713,13 +736,20 @@ def listen(self, wait=False, batch=1, **kwargs): ) return for _ in range(batch): - new_frame = self._read_frame(wait=wait, **kwargs) - if self._process_incoming_frame(*new_frame): + if self._can_read(): + if self._read_frame(wait=wait, **kwargs): + break + else: + _LOGGER.info( + "Connection cannot read frames in this state: %r", + self.state, + extra=self._network_trace_params + ) break except (OSError, IOError, SSLError, socket.error) as exc: self._error = AMQPConnectionError( ErrorCondition.SocketError, - description="Can not send frame out due to exception: " + str(exc), + description="Can not read frame due to exception: " + str(exc), error=exc, ) except Exception: # pylint:disable=try-except-raise @@ -793,13 +823,13 @@ def close(self, error=None, wait=False): :param bool wait: Whether to wait for a service Close response. Default is `False`. 
:rtype: None """ - if self.state in [ - ConnectionState.END, - ConnectionState.CLOSE_SENT, - ConnectionState.DISCARDING, - ]: - return try: + if self.state in [ + ConnectionState.END, + ConnectionState.CLOSE_SENT, + ConnectionState.DISCARDING, + ]: + return self._outgoing_close(error=error) if error: self._error = AMQPConnectionError( @@ -818,7 +848,7 @@ def close(self, error=None, wait=False): self._wait_for_response(wait, ConnectionState.END) except Exception as exc: # pylint:disable=broad-except # If error happened during closing, ignore the error and set state to END - _LOGGER.info("An error occurred when closing the connection: %r", exc) + _LOGGER.info("An error occurred when closing the connection: %r", exc, extra=self._network_trace_params) self._set_state(ConnectionState.END) finally: self._disconnect() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index dfafa33c5935..570c8b5c0110 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -157,7 +157,7 @@ def __init__( read_timeout=None, socket_settings=None, raise_on_initial_eintr=True, - **kwargs # pylint: disable=unused-argument + **kwargs ): self._quick_recv = None self.connected = False @@ -165,6 +165,7 @@ def __init__( self.raise_on_initial_eintr = raise_on_initial_eintr self._read_buffer = BytesIO() self.host, self.port = to_host_port(host, port) + self.network_trace_params = kwargs.get('network_trace_params') self.connect_timeout = connect_timeout or TIMEOUT_INTERVAL self.read_timeout = read_timeout or READ_TIMEOUT_INTERVAL @@ -185,7 +186,8 @@ def connect(self): # EINTR, EAGAIN, EWOULDBLOCK would signal that the banner # has _not_ been sent self.connected = True - except (OSError, IOError, SSLError): + except (OSError, IOError, SSLError) as e: + _LOGGER.info("Transport connection failed: %r", e, 
extra=self.network_trace_params) # if not fully connected, close socket, and reraise error if self.sock and not self.connected: self.sock.close() @@ -387,76 +389,91 @@ def _write(self, s): raise NotImplementedError("Must be overriden in subclass") def close(self): - if self.sock is not None: - self._shutdown_transport() - # Call shutdown first to make sure that pending messages - # reach the AMQP broker if the program exits after - # calling this method. - try: - self.sock.shutdown(socket.SHUT_RDWR) - except Exception as exc: # pylint: disable=broad-except - # TODO: shutdown could raise OSError, Transport endpoint is not connected if the endpoint is already - # disconnected. can we safely ignore the errors since the close operation is initiated by us. - _LOGGER.info("Transport endpoint is already disconnected: %r", exc) - self.sock.close() - self.sock = None - self.connected = False + with self.socket_lock: + if self.sock is not None: + self._shutdown_transport() + # Call shutdown first to make sure that pending messages + # reach the AMQP broker if the program exits after + # calling this method. + try: + self.sock.shutdown(socket.SHUT_RDWR) + except Exception as exc: # pylint: disable=broad-except + # TODO: shutdown could raise OSError, Transport endpoint is not connected if the endpoint is already + # disconnected. can we safely ignore the errors since the close operation is initiated by us. 
+ _LOGGER.debug( + "Transport endpoint is already disconnected: %r", + exc, + extra=self.network_trace_params + ) + self.sock.close() + self.sock = None + self.connected = False def read(self, verify_frame_type=0): - read = self._read - read_frame_buffer = BytesIO() - try: - frame_header = memoryview(bytearray(8)) - read_frame_buffer.write(read(8, buffer=frame_header, initial=True)) - - channel = struct.unpack(">H", frame_header[6:])[0] - size = frame_header[0:4] - if size == AMQP_FRAME: # Empty frame or AMQP header negotiation TODO - return frame_header, channel, None - size = struct.unpack(">I", size)[0] - offset = frame_header[4] - frame_type = frame_header[5] - if verify_frame_type is not None and frame_type != verify_frame_type: - _LOGGER.debug( - "Received invalid frame type: %r, expected: %r", frame_type, verify_frame_type - ) + with self.socket_lock: + read = self._read + read_frame_buffer = BytesIO() + try: + frame_header = memoryview(bytearray(8)) + read_frame_buffer.write(read(8, buffer=frame_header, initial=True)) + + channel = struct.unpack(">H", frame_header[6:])[0] + size = frame_header[0:4] + if size == AMQP_FRAME: # Empty frame or AMQP header negotiation TODO + return frame_header, channel, None + size = struct.unpack(">I", size)[0] + offset = frame_header[4] + frame_type = frame_header[5] + if verify_frame_type is not None and frame_type != verify_frame_type: + _LOGGER.debug( + "Received invalid frame type: %r, expected: %r", + frame_type, + verify_frame_type, + extra=self.network_trace_params + ) + raise ValueError( + f"Received invalid frame type: {frame_type}, expected: {verify_frame_type}" + ) - # >I is an unsigned int, but the argument to sock.recv is signed, - # so we know the size can be at most 2 * SIGNED_INT_MAX - payload_size = size - len(frame_header) - payload = memoryview(bytearray(payload_size)) - if size > SIGNED_INT_MAX: - read_frame_buffer.write(read(SIGNED_INT_MAX, buffer=payload)) - read_frame_buffer.write( - read(size - 
SIGNED_INT_MAX, buffer=payload[SIGNED_INT_MAX:]) - ) - else: - read_frame_buffer.write(read(payload_size, buffer=payload)) - except (socket.timeout, TimeoutError): - read_frame_buffer.write(self._read_buffer.getvalue()) - self._read_buffer = read_frame_buffer - self._read_buffer.seek(0) - raise - except (OSError, IOError, SSLError, socket.error) as exc: - # Don't disconnect for ssl read time outs - # http://bugs.python.org/issue10272 - if isinstance(exc, SSLError) and "timed out" in str(exc): - raise socket.timeout() - if get_errno(exc) not in _UNAVAIL: - self.connected = False - raise - offset -= 2 + # >I is an unsigned int, but the argument to sock.recv is signed, + # so we know the size can be at most 2 * SIGNED_INT_MAX + payload_size = size - len(frame_header) + payload = memoryview(bytearray(payload_size)) + if size > SIGNED_INT_MAX: + read_frame_buffer.write(read(SIGNED_INT_MAX, buffer=payload)) + read_frame_buffer.write( + read(size - SIGNED_INT_MAX, buffer=payload[SIGNED_INT_MAX:]) + ) + else: + read_frame_buffer.write(read(payload_size, buffer=payload)) + except (socket.timeout, TimeoutError): + read_frame_buffer.write(self._read_buffer.getvalue()) + self._read_buffer = read_frame_buffer + self._read_buffer.seek(0) + raise + except (OSError, IOError, SSLError, socket.error) as exc: + # Don't disconnect for ssl read time outs + # http://bugs.python.org/issue10272 + if isinstance(exc, SSLError) and "timed out" in str(exc): + raise socket.timeout() + if get_errno(exc) not in _UNAVAIL: + self.connected = False + _LOGGER.debug("Transport read failed: %r", exc, extra=self.network_trace_params) + raise + offset -= 2 return frame_header, channel, payload[offset:] def write(self, s): - try: - self._write(s) - except socket.timeout: - raise - except (OSError, IOError, socket.error) as exc: - if get_errno(exc) not in _UNAVAIL: - self.connected = False - raise + with self.socket_lock: + try: + self._write(s) + except socket.timeout: + raise + except (OSError, IOError, 
socket.error) as exc: + _LOGGER.debug("Transport write failed: %r", exc, extra=self.network_trace_params) + if get_errno(exc) not in _UNAVAIL: + self.connected = False + raise def receive_frame(self, **kwargs): try: @@ -719,14 +736,9 @@ def connect(self): # TODO: resolve pylance error when type: ignore is removed below, issue #22051 except (WebSocketTimeoutException, SSLError, WebSocketConnectionClosedException) as exc: # type: ignore self.close() - if isinstance(exc, WebSocketTimeoutException): - message = f'Send timed out ({str(exc)})' - elif isinstance(exc, SSLError): - message = f'Send disconnected by SSL ({str(exc)})' - else: - message = f'Send disconnected ({str(exc)})' - raise ConnectionError(message) - except (OSError, IOError, SSLError): + raise ConnectionError("Websocket failed to establish connection: %r" % exc) from exc + except (OSError, IOError, SSLError) as e: + _LOGGER.info("Websocket connection failed: %r", e, extra=self.network_trace_params) self.close() raise except ImportError: @@ -737,35 +749,43 @@ def connect(self): def _read(self, n, initial=False, buffer=None, _errnos=None): # pylint: disable=unused-argument """Read exactly n bytes from the peer.""" from websocket import WebSocketTimeoutException - - length = 0 - view = buffer or memoryview(bytearray(n)) - nbytes = self._read_buffer.readinto(view) - length += nbytes - n -= nbytes try: - while n: - data = self.ws.recv() - if len(data) <= n: - view[length : length + len(data)] = data - n -= len(data) - else: - view[length : length + n] = data[0:n] - self._read_buffer = BytesIO(data[n:]) - n = 0 - return view - except WebSocketTimeoutException as wte: - raise ConnectionError('Receive timed out (%s)' % wte) + length = 0 + view = buffer or memoryview(bytearray(n)) + nbytes = self._read_buffer.readinto(view) + length += nbytes + n -= nbytes + try: + while n: + data = self.ws.recv() + if len(data) <= n: + view[length : length + len(data)] = data + n -= len(data) + length += len(data) + else: + 
view[length : length + n] = data[0:n] + self._read_buffer = BytesIO(data[n:]) + n = 0 + return view + except AttributeError: + raise IOError("Websocket connection has already been closed.") + except WebSocketTimeoutException as wte: + raise TimeoutError('Websocket receive timed out (%s)' % wte) + except: + self._read_buffer = BytesIO(view[:length]) + raise def close(self): - if self.ws: - self._shutdown_transport() - self.ws = None + with self.socket_lock: + if self.ws: + self._shutdown_transport() + self.ws = None def _shutdown_transport(self): # TODO Sync and Async close functions named differently """Do any preliminary work in shutting down the connection.""" - self.ws.close() + if self.ws: + self.ws.close() def _write(self, s): """Completely write a string to the peer. @@ -776,10 +796,10 @@ def _write(self, s): from websocket import WebSocketConnectionClosedException, WebSocketTimeoutException try: self.ws.send_binary(s) + except AttributeError: + raise IOError("Websocket connection has already been closed.") except WebSocketTimeoutException as e: - raise ConnectionError('Send timed out (%s)' % e) - except SSLError as e: - raise ConnectionError('Send disconnected by SSL (%s)' % e) - except WebSocketConnectionClosedException as e: - raise ConnectionError('Send disconnected (%s)' % e) + raise socket.timeout('Websocket send timed out (%s)' % e) + except (WebSocketConnectionClosedException, SSLError) as e: + raise ConnectionError('Websocket disconnected: %r' % e) \ No newline at end of file diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_cbs_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_cbs_async.py index c4859df0e8ff..3906e8a145cd 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_cbs_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_cbs_async.py @@ -49,6 +49,11 @@ def __init__(self, session, auth, **kwargs): self._expires_on = None self._token = None self._refresh_window = None + 
self._network_trace_params = { + "amqpConnection": self._session._connection._container_id, + "amqpSession": self._session.name, + "amqpLink": None + } self._token_status_code = None self._token_status_description = None @@ -79,50 +84,57 @@ async def _put_token(self, token, token_type, audience, expires_on=None): async def _on_amqp_management_open_complete(self, management_open_result): if self.state in (CbsState.CLOSED, CbsState.ERROR): - _LOGGER.debug("CSB with status: %r encounters unexpected AMQP management open complete.", self.state) + _LOGGER.debug( + "CBS with status: %r encounters unexpected AMQP management open complete.", + self.state, + extra=self._network_trace_params + ) elif self.state == CbsState.OPEN: self.state = CbsState.ERROR _LOGGER.info( - "Unexpected AMQP management open complete in OPEN, CBS error occurred on connection %r.", - self._connection._container_id, # pylint:disable=protected-access + "Unexpected AMQP management open complete in OPEN, CBS error occurred.", + extra=self._network_trace_params ) elif self.state == CbsState.OPENING: self.state = CbsState.OPEN if management_open_result == ManagementOpenResult.OK else CbsState.CLOSED _LOGGER.info( - "CBS for connection %r completed opening with status: %r", - self._connection._container_id, # pylint: disable=protected-access + "CBS completed opening with status: %r", management_open_result, - ) # pylint:disable=protected-access + extra=self._network_trace_params + ) async def _on_amqp_management_error(self): if self.state == CbsState.CLOSED: - _LOGGER.debug("Unexpected AMQP error in CLOSED state.") + _LOGGER.debug("Unexpected AMQP error in CLOSED state.", extra=self._network_trace_params) elif self.state == CbsState.OPENING: self.state = CbsState.ERROR await self._mgmt_link.close() _LOGGER.info( - "CBS for connection %r failed to open with status: %r", - self._connection._container_id, ManagementOpenResult.ERROR, - ) # 
pylint:disable=protected-access + extra=self._network_trace_params + ) elif self.state == CbsState.OPEN: self.state = CbsState.ERROR - _LOGGER.info( - "CBS error occurred on connection %r.", self._connection._container_id - ) # pylint:disable=protected-access + _LOGGER.info("CBS error occurred.", extra=self._network_trace_params) async def _on_execute_operation_complete( self, execute_operation_result, status_code, status_description, _, error_condition=None ): if error_condition: - _LOGGER.info("CBS Put token error: %r", error_condition) + _LOGGER.info( + "CBS Put token error: %r", + error_condition, + extra=self._network_trace_params + ) self.auth_state = CbsAuthState.ERROR return - _LOGGER.info( + _LOGGER.debug( "CBS Put token result (%r), status code: %s, status_description: %s.", execute_operation_result, status_code, status_description, + extra=self._network_trace_params ) self._token_status_code = status_code self._token_status_description = status_description @@ -139,17 +151,26 @@ async def _on_execute_operation_complete( async def _update_status(self): if self.auth_state == CbsAuthState.OK or self.auth_state == CbsAuthState.REFRESH_REQUIRED: - _LOGGER.debug("update_status In refresh required or OK.") is_expired, is_refresh_required = check_expiration_and_refresh_status( self._expires_on, self._refresh_window ) # pylint:disable=line-too-long - _LOGGER.debug("is expired == %r, is refresh required == %r", is_expired, is_refresh_required) + _LOGGER.debug( + "CBS status check: state == %r, expired == %r, refresh required == %r", + self.auth_state, + is_expired, + is_refresh_required, + extra=self._network_trace_params + ) if is_expired: self.auth_state = CbsAuthState.EXPIRED elif is_refresh_required: self.auth_state = CbsAuthState.REFRESH_REQUIRED elif self.auth_state == CbsAuthState.IN_PROGRESS: - _LOGGER.debug("In update status, in progress. token put time: %r", self._token_put_time) + _LOGGER.debug( + "CBS update in progress. 
Token put time: %r", + self._token_put_time, + extra=self._network_trace_params + ) put_timeout = check_put_timeout_status(self._auth_timeout, self._token_put_time) if put_timeout: self.auth_state = CbsAuthState.TIMEOUT @@ -176,8 +197,16 @@ async def close(self): async def update_token(self): self.auth_state = CbsAuthState.IN_PROGRESS access_token = await self._auth.get_token() - if not access_token.token: - _LOGGER.debug("update_token received an empty token") + if not access_token: + _LOGGER.info( + "Token refresh function received an empty token object.", + extra=self._network_trace_params + ) + elif not access_token.token: + _LOGGER.info( + "Token refresh function received an empty token.", + extra=self._network_trace_params + ) self._expires_on = access_token.expires_on expires_in = self._expires_on - int(utc_now().timestamp()) self._refresh_window = int(float(expires_in) * 0.1) @@ -203,8 +232,9 @@ async def handle_token(self): return True if self.auth_state == CbsAuthState.REFRESH_REQUIRED: _LOGGER.info( - "Token on connection %r will expire soon - attempting to refresh.", self._connection._container_id - ) # pylint:disable=protected-access + "Token will expire soon - attempting to refresh.", + extra=self._network_trace_params + ) await self.update_token() return False if self.auth_state == CbsAuthState.FAILURE: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py index ba6c24ad1125..30e8c685f1ce 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py @@ -142,14 +142,21 @@ async def _keep_alive_async(self): current_time = time.time() elapsed_time = current_time - start_time if elapsed_time >= self._keep_alive_interval: - _logger.info("Keeping %r connection alive. 
%r", - self.__class__.__name__, - self._connection.container_id) + _logger.debug( + "Keeping %r connection alive.", + self.__class__.__name__, + extra=self._network_trace_params + ) await asyncio.shield(self._connection.work_async()) start_time = current_time await asyncio.sleep(1) except Exception as e: # pylint: disable=broad-except - _logger.info("Connection keep-alive for %r failed: %r.", self.__class__.__name__, e) + _logger.info( + "Connection keep-alive for %r failed: %r.", + self.__class__.__name__, + e, + extra=self._network_trace_params + ) async def __aenter__(self): """Run Client in an async context manager.""" @@ -221,7 +228,6 @@ async def open_async(self, connection=None): # pylint: disable=protected-access if self._session: return # already open. - _logger.debug("Opening client connection.") if connection: self._connection = connection self._external_connection = True @@ -254,9 +260,13 @@ async def open_async(self, connection=None): auth_timeout=self._auth_timeout ) await self._cbs_authenticator.open() + self._network_trace_params["amqpConnection"] = self._connection._container_id + self._network_trace_params["amqpSession"] = self._session.name self._shutdown = False - if self._keep_alive_interval: - self._keep_alive_thread = asyncio.ensure_future(self._keep_alive_async()) + # TODO: Looks like this is broken - should re-enable later and test + # correct empty frame behaviour + # if self._keep_alive_interval: + # self._keep_alive_thread = asyncio.ensure_future(self._keep_alive_async()) async def close_async(self): """Close the client asynchronously. 
This includes closing the Session @@ -279,6 +289,8 @@ async def close_async(self): if not self._external_connection: await self._connection.close() self._connection = None + self._network_trace_params["amqpConnection"] = None + self._network_trace_params["amqpSession"] = None async def auth_complete_async(self): """Whether the authentication handshake is complete during @@ -318,6 +330,7 @@ async def do_work_async(self, **kwargs): :rtype: bool :raises: TimeoutError if CBS authentication timeout reached. """ + if self._shutdown: return False if not await self.client_ready_async(): @@ -493,13 +506,8 @@ async def _client_run_async(self, **kwargs): :rtype: bool """ - try: - await self._link.update_pending_deliveries() - await self._connection.listen(wait=self._socket_timeout, **kwargs) - except ValueError: - _logger.info("Timeout reached, closing sender.") - self._shutdown = True - return False + await self._link.update_pending_deliveries() + await self._connection.listen(wait=self._socket_timeout, **kwargs) return True async def _transfer_message_async(self, message_delivery, timeout=0): @@ -562,6 +570,11 @@ async def _send_message_impl_async(self, message, **kwargs): running = True while running and message_delivery.state not in MESSAGE_DELIVERY_DONE_STATES: running = await self.do_work_async() + if message_delivery.state not in MESSAGE_DELIVERY_DONE_STATES: + raise MessageException( + condition=ErrorCondition.ClientError, + description="Send failed - connection not running." 
+ ) if message_delivery.state in ( MessageDeliveryState.Error, @@ -714,7 +727,7 @@ async def _client_run_async(self, **kwargs): await self._link.flow() await self._connection.listen(wait=self._socket_timeout, **kwargs) except ValueError: - _logger.info("Timeout reached, closing receiver.") + _logger.info("Timeout reached, closing receiver.", extra=self._network_trace_params) self._shutdown = True return False return True @@ -732,10 +745,6 @@ async def _message_received_async(self, frame, message): await self._message_received_callback(message) if not self._streaming_receive: self._received_messages.put((frame, message)) - # TODO: do we need settled property for a message? - # elif not message.settled: - # # Message was received with callback processing and wasn't settled. - # _logger.info("Message was not settled.") async def _receive_message_batch_impl_async(self, max_batch_size=None, on_message_received=None, timeout=0): self._message_received_callback = on_message_received @@ -746,7 +755,7 @@ async def _receive_message_batch_impl_async(self, max_batch_size=None, on_messag await self.open_async() while len(batch) < max_batch_size: try: - # TODO: This looses the transfer frame data + # TODO: This drops the transfer frame data _, message = self._received_messages.get_nowait() batch.append(message) self._received_messages.task_done() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py index c816d04dc440..aaaf43bfe420 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py @@ -85,6 +85,11 @@ def __init__(self, endpoint, **kwargs): # pylint:disable=too-many-statements custom_parsed_url = urlparse(custom_endpoint_address) custom_port = custom_parsed_url.port or WEBSOCKET_PORT custom_endpoint = 
f"{custom_parsed_url.hostname}:{custom_port}{custom_parsed_url.path}" + self._container_id = kwargs.pop("container_id", None) or str( + uuid.uuid4() + ) # type: str + self._network_trace = kwargs.get("network_trace", False) + self._network_trace_params = {"amqpConnection": self._container_id, "amqpSession": None, "amqpLink": None} transport = kwargs.get("transport") self._transport_type = kwargs.pop("transport_type", TransportType.Amqp) @@ -101,14 +106,15 @@ def __init__(self, endpoint, **kwargs): # pylint:disable=too-many-statements host=endpoint, credential=kwargs["sasl_credential"], custom_endpoint=custom_endpoint, + network_trace_params=self._network_trace_params, **kwargs, ) else: - self._transport = AsyncTransport(parsed_url.netloc, **kwargs) + self._transport = AsyncTransport( + parsed_url.netloc, + network_trace_params=self._network_trace_params, + **kwargs) - self._container_id = kwargs.pop("container_id", None) or str( - uuid.uuid4() - ) # type: str self._max_frame_size = kwargs.pop( "max_frame_size", MAX_FRAME_SIZE_BYTES ) # type: int @@ -140,12 +146,6 @@ def __init__(self, endpoint, **kwargs): # pylint:disable=too-many-statements self._last_frame_received_time = None # type: Optional[float] self._last_frame_sent_time = None # type: Optional[float] self._idle_wait_time = kwargs.get("idle_wait_time", 0.1) # type: float - self._network_trace = kwargs.get("network_trace", False) - self._network_trace_params = { - "connection": self._container_id, - "session": None, - "link": None, - } self._error = None self._outgoing_endpoints = {} # type: Dict[int, Session] self._incoming_endpoints = {} # type: Dict[int, Session] @@ -165,10 +165,10 @@ async def _set_state(self, new_state): previous_state = self.state self.state = new_state _LOGGER.info( - "Connection '%s' state changed: %r -> %r", - self._container_id, + "Connection state changed: %r -> %r", previous_state, new_state, + extra=self._network_trace_params ) for session in self._outgoing_endpoints.values(): 
await session._on_connection_state_change() # pylint:disable=protected-access @@ -191,7 +191,7 @@ async def _connect(self): await self._outgoing_header() await self._set_state(ConnectionState.HDR_SENT) if not self._allow_pipelined_open: - await self._process_incoming_frame(*(await self._read_frame(wait=True))) + await self._read_frame(wait=True) if self.state != ConnectionState.HDR_EXCH: await self._disconnect() raise ValueError( @@ -223,8 +223,7 @@ def _can_read(self): """Whether the connection is in a state where it is legal to read for incoming frames.""" return self.state not in (ConnectionState.CLOSE_RCVD, ConnectionState.END) - async def _read_frame(self, wait=True, **kwargs): # type: ignore # TODO: missing return - # type: (bool, Any) -> Tuple[int, Optional[Tuple[int, NamedTuple]]] + async def _read_frame(self, wait: Union[bool, int, float] = True, **kwargs) -> bool: """Read an incoming frame from the transport. :param Union[bool, float] wait: Whether to block on the socket while waiting for an incoming frame. @@ -235,15 +234,15 @@ async def _read_frame(self, wait=True, **kwargs): # type: ignore # TODO: missin :returns: A tuple with the incoming channel number, and the frame in the form or a tuple of performative descriptor and field values. """ - if self._can_read(): - if wait is False: - timeout = 1 # TODO: What should this default be? - elif wait is True: - timeout = None - else: - timeout = wait - return await self._transport.receive_frame(timeout=timeout, **kwargs) - _LOGGER.warning("Cannot read frame in current state: %r", self.state) + timeout: Optional[Union[int, float]] = None + if wait is False: + timeout = 1 # TODO: What should this default be? 
+ elif wait is True: + timeout = None + else: + timeout = wait + new_frame = await self._transport.receive_frame(timeout=timeout, **kwargs) + return await self._process_incoming_frame(*new_frame) def _can_write(self): # type: () -> bool @@ -284,7 +283,7 @@ async def _send_frame(self, channel, frame, timeout=None, **kwargs): error=exc, ) else: - _LOGGER.warning("Cannot write frame in current state: %r", self.state) + _LOGGER.info("Cannot write frame in current state: %r", self.state, extra=self._network_trace_params) def _get_next_outgoing_channel(self): # type: () -> int @@ -311,7 +310,7 @@ async def _outgoing_empty(self): # type: () -> None """Send an empty frame to prevent the connection from reaching an idle timeout.""" if self._network_trace: - _LOGGER.info("-> empty()", extra=self._network_trace_params) + _LOGGER.debug("-> EmptyFrame()", extra=self._network_trace_params) try: raise self._error except TypeError: @@ -332,16 +331,14 @@ async def _outgoing_header(self): """Send the AMQP protocol header to initiate the connection.""" self._last_frame_sent_time = time.time() if self._network_trace: - _LOGGER.info( - "-> header(%r)", HEADER_FRAME, extra=self._network_trace_params - ) + _LOGGER.debug("-> Header(%r)", HEADER_FRAME, extra=self._network_trace_params) await self._transport.write(HEADER_FRAME) async def _incoming_header(self, _, frame): # type: (int, bytes) -> None """Process an incoming AMQP protocol header and update the connection state.""" if self._network_trace: - _LOGGER.info("<- header(%r)", frame, extra=self._network_trace_params) + _LOGGER.debug("<- Header(%r)", frame, extra=self._network_trace_params) if self.state == ConnectionState.START: await self._set_state(ConnectionState.HDR_RCVD) elif self.state == ConnectionState.HDR_SENT: @@ -371,7 +368,7 @@ async def _outgoing_open(self): properties=self._properties, ) if self._network_trace: - _LOGGER.info("-> %r", open_frame, extra=self._network_trace_params) + _LOGGER.debug("-> %r", open_frame, 
extra=self._network_trace_params) await self._send_frame(0, open_frame) async def _incoming_open(self, channel, frame): @@ -398,9 +395,9 @@ async def _incoming_open(self, channel, frame): """ # TODO: Add type hints for full frame tuple contents. if self._network_trace: - _LOGGER.info("<- %r", OpenFrame(*frame), extra=self._network_trace_params) + _LOGGER.debug("<- %r", OpenFrame(*frame), extra=self._network_trace_params) if channel != 0: - _LOGGER.error("OPEN frame received on a channel that is not 0.") + _LOGGER.error("OPEN frame received on a channel that is not 0.", extra=self._network_trace_params) await self.close( error=AMQPError( condition=ErrorCondition.NotAllowed, @@ -409,7 +406,7 @@ async def _incoming_open(self, channel, frame): ) await self._set_state(ConnectionState.END) if self.state == ConnectionState.OPENED: - _LOGGER.error("OPEN frame received in the OPENED state.") + _LOGGER.error("OPEN frame received in the OPENED state.", extra=self._network_trace_params) await self.close() if frame[4]: self._remote_idle_timeout = frame[4] / 1000 # Convert to seconds @@ -422,19 +419,17 @@ async def _incoming_open(self, channel, frame): # If any of the values in the received open frame are invalid then the connection shall be closed. # The error amqp:invalid-field shall be set in the error.condition field of the CLOSE frame. await self.close( - error=cast( - AMQPError, - AMQPConnectionError( - condition=ErrorCondition.InvalidField, - description="Failed parsing OPEN frame: Max frame size is less than supported minimum.", - ), + error=AMQPError( + condition=ErrorCondition.InvalidField, + description="Failed parsing OPEN frame: Max frame size is less than supported minimum.", ) ) _LOGGER.error( - "Failed parsing OPEN frame: Max frame size is less than supported minimum." 
+ "Failed parsing OPEN frame: Max frame size is less than supported minimum.", + extra=self._network_trace_params ) - else: - self._remote_max_frame_size = frame[2] + return + self._remote_max_frame_size = frame[2] if self.state == ConnectionState.OPEN_SENT: await self._set_state(ConnectionState.OPENED) elif self.state == ConnectionState.HDR_EXCH: @@ -445,17 +440,17 @@ async def _incoming_open(self, channel, frame): await self.close( error=AMQPError( condition=ErrorCondition.IllegalState, - description=f"connection is an illegal state: {self.state}", + description=f"Connection is an illegal state: {self.state}", ) ) - _LOGGER.error("connection is an illegal state: %r", self.state) + _LOGGER.error("Connection is an illegal state: %r", self.state, extra=self._network_trace_params) async def _outgoing_close(self, error=None): # type: (Optional[AMQPError]) -> None """Send a Close frame to shutdown connection with optional error information.""" close_frame = CloseFrame(error=error) if self._network_trace: - _LOGGER.info("-> %r", close_frame, extra=self._network_trace_params) + _LOGGER.debug("-> %r", close_frame, extra=self._network_trace_params) await self._send_frame(0, close_frame) async def _incoming_close(self, channel, frame): @@ -468,7 +463,7 @@ async def _incoming_close(self, channel, frame): """ if self._network_trace: - _LOGGER.info("<- %r", CloseFrame(*frame), extra=self._network_trace_params) + _LOGGER.debug("<- %r", CloseFrame(*frame), extra=self._network_trace_params) disconnect_states = [ ConnectionState.HDR_RCVD, ConnectionState.HDR_EXCH, @@ -478,12 +473,14 @@ async def _incoming_close(self, channel, frame): ] if self.state in disconnect_states: await self._disconnect() - await self._set_state(ConnectionState.END) return close_error = None if channel > self._channel_max: - _LOGGER.error("Invalid channel") + _LOGGER.error( + "CLOSE frame received on a channel greater than supported max.", + extra=self._network_trace_params + ) close_error = AMQPError (
condition=ErrorCondition.InvalidField, description="Invalid channel", @@ -493,14 +490,14 @@ async def _incoming_close(self, channel, frame): await self._set_state(ConnectionState.CLOSE_RCVD) await self._outgoing_close(error=close_error) await self._disconnect() - await self._set_state(ConnectionState.END) if frame[0]: self._error = AMQPConnectionError( condition=frame[0][0], description=frame[0][1], info=frame[0][2] ) _LOGGER.error( - "Connection error: %r",frame[0] + "Connection closed with error: %r", frame[0], + extra=self._network_trace_params ) async def _incoming_begin(self, channel, frame): @@ -558,6 +555,10 @@ async def _incoming_end(self, channel, frame): condition=ErrorCondition.ConnectionCloseForced, description="Invalid channel number received" )) + _LOGGER.error( + "END frame received on invalid channel. Closing connection.", + extra=self._network_trace_params + ) return async def _process_incoming_frame( @@ -624,7 +625,7 @@ async def _process_incoming_frame( return True if performative == 1: return False # TODO: incoming EMPTY - _LOGGER.error("Unrecognized incoming frame: %s", frame) + _LOGGER.error("Unrecognized incoming frame: %r", frame, extra=self._network_trace_params) return True except KeyError: return True # TODO: channel error @@ -652,6 +653,10 @@ async def _process_outgoing_frame(self, channel, frame): cast(float, self._idle_timeout), cast(float, self._last_frame_received_time), ) or (await self._get_remote_timeout(now)): + _LOGGER.info( + "No frame received for the idle timeout. 
Closing connection.", + extra=self._network_trace_params + ) await self.close( error=AMQPError( condition=ErrorCondition.ConnectionCloseForced, @@ -704,10 +709,6 @@ async def _wait_for_response(self, wait, end_state): await asyncio.sleep(self._idle_wait_time) await self.listen(wait=False) - async def _listen_one_frame(self, **kwargs): - new_frame = await self._read_frame(**kwargs) - return await self._process_incoming_frame(*new_frame) - async def listen(self, wait=False, batch=1, **kwargs): # type: (Union[float, int, bool], int, Any) -> None """Listen on the socket for incoming frames and process them. @@ -734,6 +735,10 @@ async def listen(self, wait=False, batch=1, **kwargs): cast(float, self._idle_timeout), cast(float, self._last_frame_received_time), ) or (await self._get_remote_timeout(now)): + _LOGGER.info( + "No frame received for the idle timeout. Closing connection.", + extra=self._network_trace_params + ) await self.close( error=AMQPError( condition=ErrorCondition.ConnectionCloseForced, @@ -750,15 +755,20 @@ async def listen(self, wait=False, batch=1, **kwargs): ) return for _ in range(batch): - if await asyncio.ensure_future( - self._listen_one_frame(wait=wait, **kwargs) - ): - # TODO: compare the perf difference between ensure_future and direct await + if self._can_read(): + if await self._read_frame(wait=wait, **kwargs): + break + else: + _LOGGER.info( + "Connection cannot read frames in this state: %r", + self.state, + extra=self._network_trace_params + ) break except (OSError, IOError, SSLError, socket.error) as exc: self._error = AMQPConnectionError( ErrorCondition.SocketError, - description="Can not send frame out due to exception: " + str(exc), + description="Can not read frame due to exception: " + str(exc), error=exc, ) @@ -829,13 +839,13 @@ async def close(self, error=None, wait=False): :param bool wait: Whether to wait for a service Close response. Default is `False`. 
:rtype: None """ - if self.state in [ - ConnectionState.END, - ConnectionState.CLOSE_SENT, - ConnectionState.DISCARDING, - ]: - return try: + if self.state in [ + ConnectionState.END, + ConnectionState.CLOSE_SENT, + ConnectionState.DISCARDING, + ]: + return await self._outgoing_close(error=error) if error: self._error = AMQPConnectionError( @@ -854,7 +864,7 @@ async def close(self, error=None, wait=False): await self._wait_for_response(wait, ConnectionState.END) except Exception as exc: # pylint:disable=broad-except # If error happened during closing, ignore the error and set state to END - _LOGGER.info("An error occurred when closing the connection: %r", exc) + _LOGGER.info("An error occurred when closing the connection: %r", exc, extra=self._network_trace_params) await self._set_state(ConnectionState.END) finally: await self._disconnect() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_link_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_link_async.py index 174fb61ee128..8b8b015d294e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_link_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_link_async.py @@ -83,7 +83,7 @@ def __init__(self, session, handle, name, role, **kwargs): self.network_trace = kwargs["network_trace"] self.network_trace_params = kwargs["network_trace_params"] - self.network_trace_params["link"] = self.name + self.network_trace_params["amqpLink"] = self.name self._session = session self._is_closed = False self._on_link_state_change = kwargs.get("on_link_state_change") @@ -158,16 +158,16 @@ async def _outgoing_attach(self): properties=self.properties, ) if self.network_trace: - _LOGGER.info("-> %r", attach_frame, extra=self.network_trace_params) + _LOGGER.debug("-> %r", attach_frame, extra=self.network_trace_params) await self._session._outgoing_attach(attach_frame) # pylint: disable=protected-access async def _incoming_attach(self, frame): if self.network_trace: - 
_LOGGER.info("<- %r", AttachFrame(*frame), extra=self.network_trace_params) + _LOGGER.debug("<- %r", AttachFrame(*frame), extra=self.network_trace_params) if self._is_closed: raise ValueError("Invalid link") if not frame[5] or not frame[6]: - _LOGGER.info("Cannot get source or target. Detaching link") + _LOGGER.info("Cannot get source or target. Detaching link", extra=self.network_trace_params) await self._set_state(LinkState.DETACHED) raise ValueError("Invalid link") self.remote_handle = frame[1] # handle @@ -189,7 +189,7 @@ async def _incoming_attach(self, frame): frame[6] = Target(*frame[6]) await self._on_attach(AttachFrame(*frame)) except Exception as e: # pylint: disable=broad-except - _LOGGER.warning("Callback for link attach raised error: %s", e) + _LOGGER.warning("Callback for link attach raised error: %s", e, extra=self.network_trace_params) async def _outgoing_flow(self, **kwargs): flow_frame = { @@ -212,14 +212,14 @@ async def _incoming_disposition(self, frame): async def _outgoing_detach(self, close=False, error=None): detach_frame = DetachFrame(handle=self.handle, closed=close, error=error) if self.network_trace: - _LOGGER.info("-> %r", detach_frame, extra=self.network_trace_params) + _LOGGER.debug("-> %r", detach_frame, extra=self.network_trace_params) await self._session._outgoing_detach(detach_frame) # pylint: disable=protected-access if close: self._is_closed = True async def _incoming_detach(self, frame): if self.network_trace: - _LOGGER.info("<- %r", DetachFrame(*frame), extra=self.network_trace_params) + _LOGGER.debug("<- %r", DetachFrame(*frame), extra=self.network_trace_params) if self.state == LinkState.ATTACHED: await self._outgoing_detach(close=frame[1]) # closed elif frame[1] and not self._is_closed and self.state in [LinkState.ATTACH_SENT, LinkState.ATTACH_RCVD]: @@ -254,7 +254,7 @@ async def detach(self, close=False, error=None): await self._outgoing_detach(close=close, error=error) await self._set_state(LinkState.DETACH_SENT) except 
Exception as exc: # pylint: disable=broad-except - _LOGGER.info("An error occurred when detaching the link: %r", exc) + _LOGGER.info("An error occurred when detaching the link: %r", exc, extra=self.network_trace_params) await self._set_state(LinkState.DETACHED) async def flow(self, *, link_credit: Optional[int] = None, **kwargs) -> None: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_management_link_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_management_link_async.py index 3928f93d2ff7..94f3163accfd 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_management_link_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_management_link_async.py @@ -38,12 +38,14 @@ def __init__(self, session, endpoint, **kwargs): self.state = ManagementLinkState.IDLE self._pending_operations = [] self._session = session + self._network_trace_params = kwargs.get('network_trace_params') self._request_link: SenderLink = session.create_sender_link( endpoint, source_address=endpoint, on_link_state_change=self._on_sender_state_change, send_settle_mode=SenderSettleMode.Unsettled, rcv_settle_mode=ReceiverSettleMode.First, + network_trace=kwargs.get("network_trace", False) ) self._response_link: ReceiverLink = session.create_receiver_link( endpoint, @@ -52,6 +54,7 @@ def __init__(self, session, endpoint, **kwargs): on_transfer=self._on_message_received, send_settle_mode=SenderSettleMode.Unsettled, rcv_settle_mode=ReceiverSettleMode.First, + network_trace=kwargs.get("network_trace", False) ) self._on_amqp_management_error = kwargs.get("on_amqp_management_error") self._on_amqp_management_open_complete = kwargs.get("on_amqp_management_open_complete") @@ -70,7 +73,12 @@ async def __aexit__(self, *args): await self.close() async def _on_sender_state_change(self, previous_state, new_state): - _LOGGER.info("Management link sender state changed: %r -> %r", previous_state, new_state) + _LOGGER.info( + "Management link 
sender state changed: %r -> %r", + previous_state, + new_state, + extra=self._network_trace_params + ) if new_state == previous_state: return if self.state == ManagementLinkState.OPENING: @@ -95,7 +103,12 @@ async def _on_sender_state_change(self, previous_state, new_state): return async def _on_receiver_state_change(self, previous_state, new_state): - _LOGGER.info("Management link receiver state changed: %r -> %r", previous_state, new_state) + _LOGGER.info( + "Management link receiver state changed: %r -> %r", + previous_state, + new_state, + extra=self._network_trace_params + ) if new_state == previous_state: return if self.state == ManagementLinkState.OPENING: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_management_operation_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_management_operation_async.py index f7ebb5f667bf..e5830d7d0ff8 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_management_operation_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_management_operation_async.py @@ -28,6 +28,11 @@ def __init__(self, session, endpoint='$management', **kwargs): self._session = session self._connection = self._session._connection + self._network_trace_params = { + "amqpConnection": self._session._connection._container_id, + "amqpSession": self._session.name, + "amqpLink": None + } self._mgmt_link = self._session.create_request_response_link_pair( endpoint=endpoint, on_amqp_management_open_complete=self._on_amqp_management_open_complete, @@ -61,22 +66,22 @@ async def _on_execute_operation_complete( error=None ): _LOGGER.debug( - "mgmt operation completed, operation id: %r; operation_result: %r; status_code: %r; " - "status_description: %r, raw_message: %r, error: %r", + "Management operation completed, id: %r; result: %r; code: %r; description: %r, error: %r", operation_id, operation_result, status_code, status_description, - raw_message, - error + error, + 
extra=self._network_trace_params ) if operation_result in\ (ManagementExecuteOperationResult.ERROR, ManagementExecuteOperationResult.LINK_CLOSED): self._mgmt_error = error _LOGGER.error( - "Failed to complete mgmt operation due to error: %r. The management request message is: %r", - error, raw_message + "Failed to complete management operation due to error: %r.", + error, + extra=self._network_trace_params ) else: self._responses[operation_id] = (status_code, status_description, raw_message) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_receiver_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_receiver_async.py index a193d482d96a..7d3c6c540160 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_receiver_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_receiver_async.py @@ -41,13 +41,13 @@ async def _process_incoming_message(self, frame, message): try: return await self._on_transfer(frame, message) except Exception as e: # pylint: disable=broad-except - _LOGGER.error("Handler function failed with error: %r", e) + _LOGGER.error("Transfer callback function failed with error: %r", e, extra=self.network_trace_params) return None async def _incoming_attach(self, frame): await super(ReceiverLink, self)._incoming_attach(frame) if frame[9] is None: # initial_delivery_count - _LOGGER.info("Cannot get initial-delivery-count. Detaching link") + _LOGGER.info("Cannot get initial-delivery-count. Detaching link", extra=self.network_trace_params) await self._set_state(LinkState.DETACHED) # TODO: Send detach now? 
self.delivery_count = frame[9] self.current_link_credit = self.link_credit @@ -55,7 +55,7 @@ async def _incoming_attach(self, frame): async def _incoming_transfer(self, frame): if self.network_trace: - _LOGGER.debug("<- %r", TransferFrame(*frame), extra=self.network_trace_params) + _LOGGER.debug("<- %r", TransferFrame(payload=b"***", *frame[:-1]), extra=self.network_trace_params) self.current_link_credit -= 1 self.delivery_count += 1 self.received_delivery_id = frame[1] # delivery_id @@ -69,8 +69,6 @@ async def _incoming_transfer(self, frame): self._received_payload = bytearray() else: message = decode_payload(frame[11]) - if self.network_trace: - _LOGGER.debug(" %r", message, extra=self.network_trace_params) delivery_state = await self._process_incoming_message(frame, message) if not frame[4] and delivery_state: # settled await self._outgoing_disposition( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sender_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sender_async.py index aaf00c58cc40..29a4c052baa3 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sender_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sender_async.py @@ -12,9 +12,6 @@ from .._encode import encode_payload from ._link_async import Link from ..constants import SessionTransferState, LinkDeliverySettleReason, LinkState, Role, SenderSettleMode, SessionState -from ..performatives import ( - TransferFrame, -) from ..error import AMQPLinkError, ErrorCondition, MessageException _LOGGER = logging.getLogger(__name__) @@ -30,13 +27,18 @@ def __init__(self, **kwargs): self.transfer_state = None self.timeout = kwargs.get("timeout") self.settled = kwargs.get("settled", False) + self._network_trace_params = kwargs.get('network_trace_params') async def on_settled(self, reason, state): if self.on_delivery_settled and not self.settled: try: await self.on_delivery_settled(reason, state) except Exception as e: # 
pylint:disable=broad-except - _LOGGER.warning("Message 'on_send_complete' callback failed: %r", e) + _LOGGER.warning( + "Message 'on_send_complete' callback failed: %r", + e, + extra=self._network_trace_params + ) self.settled = True @@ -76,7 +78,10 @@ async def _incoming_flow(self, frame): rcv_delivery_count = frame[5] # delivery_count if frame[4] is not None: # handle if rcv_link_credit is None or rcv_delivery_count is None: - _LOGGER.info("Unable to get link-credit or delivery-count from incoming ATTACH. Detaching link.") + _LOGGER.info( + "Unable to get link-credit or delivery-count from incoming ATTACH. Detaching link.", + extra=self.network_trace_params + ) await self._remove_pending_deliveries() await self._set_state(LinkState.DETACHED) # TODO: Send detach now? else: @@ -100,12 +105,10 @@ async def _outgoing_transfer(self, delivery): "batchable": None, "payload": output, } - if self.network_trace: - _LOGGER.debug( - "-> %r", TransferFrame(delivery_id="", **delivery.frame), extra=self.network_trace_params - ) - _LOGGER.debug(" %r", delivery.message, extra=self.network_trace_params) - await self._session._outgoing_transfer(delivery) # pylint:disable=protected-access + await self._session._outgoing_transfer( # pylint:disable=protected-access + delivery, + self.network_trace_params if self.network_trace else None + ) sent_and_settled = False if delivery.transfer_state == SessionTransferState.OKAY: self.delivery_count = delivery_count @@ -175,6 +178,7 @@ async def send_transfer(self, message, *, send_async=False, **kwargs): timeout=kwargs.get("timeout"), message=message, settled=settled, + network_trace_params=self.network_trace_params ) if self.current_link_credit == 0 or send_async: self._pending_deliveries.append(delivery) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_session_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_session_async.py index 13d54cdccd2e..fd1cb14218cf 100644 --- 
a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_session_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_session_async.py @@ -61,7 +61,7 @@ def __init__(self, connection, channel, **kwargs): self.idle_wait_time = kwargs.get("idle_wait_time", 0.1) self.network_trace = kwargs["network_trace"] self.network_trace_params = kwargs["network_trace_params"] - self.network_trace_params["session"] = self.name + self.network_trace_params["amqpSession"] = self.name self.links = {} self._connection = connection @@ -127,12 +127,12 @@ async def _outgoing_begin(self): properties=self.properties, ) if self.network_trace: - _LOGGER.info("-> %r", begin_frame, extra=self.network_trace_params) + _LOGGER.debug("-> %r", begin_frame, extra=self.network_trace_params) await self._connection._process_outgoing_frame(self.channel, begin_frame) # pylint: disable=protected-access async def _incoming_begin(self, frame): if self.network_trace: - _LOGGER.info("<- %r", BeginFrame(*frame), extra=self.network_trace_params) + _LOGGER.debug("<- %r", BeginFrame(*frame), extra=self.network_trace_params) self.handle_max = frame[4] # handle_max self.next_incoming_id = frame[1] # next_outgoing_id self.remote_incoming_window = frame[2] # incoming_window @@ -148,12 +148,12 @@ async def _incoming_begin(self, frame): async def _outgoing_end(self, error=None): end_frame = EndFrame(error=error) if self.network_trace: - _LOGGER.info("-> %r", end_frame, extra=self.network_trace_params) + _LOGGER.debug("-> %r", end_frame, extra=self.network_trace_params) await self._connection._process_outgoing_frame(self.channel, end_frame) # pylint: disable=protected-access async def _incoming_end(self, frame): if self.network_trace: - _LOGGER.info("<- %r", EndFrame(*frame), extra=self.network_trace_params) + _LOGGER.debug("<- %r", EndFrame(*frame), extra=self.network_trace_params) if self.state not in [ SessionState.END_RCVD, SessionState.END_SENT, @@ -177,6 +177,10 @@ async def 
_incoming_attach(self, frame): try: outgoing_handle = self._get_next_output_handle() except ValueError: + _LOGGER.error( + "Unable to attach new link - cannot allocate more handles.", + extra=self.network_trace_params + ) # detach the link that would have been set. await self.links[frame[0].decode("utf-8")].detach( error=AMQPError( @@ -194,8 +198,13 @@ async def _incoming_attach(self, frame): self.links[frame[0]] = new_link self._output_handles[outgoing_handle] = new_link self._input_handles[frame[1]] = new_link - except ValueError: + except ValueError as e: # Reject Link + _LOGGER.error( + "Unable to attach new link: %r", + e, + extra=self.network_trace_params + ) await self._input_handles[frame[1]].detach() async def _outgoing_flow(self, frame=None): @@ -210,12 +219,12 @@ async def _outgoing_flow(self, frame=None): ) flow_frame = FlowFrame(**link_flow) if self.network_trace: - _LOGGER.info("-> %r", flow_frame, extra=self.network_trace_params) + _LOGGER.debug("-> %r", flow_frame, extra=self.network_trace_params) await self._connection._process_outgoing_frame(self.channel, flow_frame) # pylint: disable=protected-access async def _incoming_flow(self, frame): if self.network_trace: - _LOGGER.info("<- %r", FlowFrame(*frame), extra=self.network_trace_params) + _LOGGER.debug("<- %r", FlowFrame(*frame), extra=self.network_trace_params) self.next_incoming_id = frame[2] # next_outgoing_id remote_incoming_id = frame[0] or self.next_outgoing_id # next_incoming_id TODO "initial-outgoing-id" self.remote_incoming_window = remote_incoming_id + frame[1] - self.next_outgoing_id # incoming_window @@ -227,7 +236,7 @@ async def _incoming_flow(self, frame): if self.remote_incoming_window > 0 and not link._is_closed: # pylint: disable=protected-access await link._incoming_flow(frame) # pylint: disable=protected-access - async def _outgoing_transfer(self, delivery): + async def _outgoing_transfer(self, delivery, network_trace_params): if self.state != SessionState.MAPPED: 
delivery.transfer_state = SessionTransferState.ERROR if self.remote_incoming_window <= 0: @@ -264,11 +273,24 @@ async def _outgoing_transfer(self, delivery): "resume": delivery.frame["resume"], "aborted": delivery.frame["aborted"], "batchable": delivery.frame["batchable"], - "payload": payload[start_idx : start_idx + available_frame_size], "delivery_id": self.next_outgoing_id, } + if network_trace_params: + # We determine the logging for the outgoing Transfer frames based on the source + # Link configuration rather than the Session, because it's only at the Session + # level that we can determine how many outgoing frames are needed and their + # delivery IDs. + # TODO: Obscuring the payload for now to investigate the potential for leaks. + _LOGGER.debug( + "-> %r", TransferFrame(payload=b"***", **tmp_delivery_frame), + extra=network_trace_params + ) await self._connection._process_outgoing_frame( # pylint: disable=protected-access - self.channel, TransferFrame(**tmp_delivery_frame) + self.channel, + TransferFrame( + payload=payload[start_idx : start_idx + available_frame_size], + **tmp_delivery_frame + ) ) start_idx += available_frame_size remaining_payload_cnt -= available_frame_size @@ -285,11 +307,21 @@ async def _outgoing_transfer(self, delivery): "resume": delivery.frame["resume"], "aborted": delivery.frame["aborted"], "batchable": delivery.frame["batchable"], - "payload": payload[start_idx:], "delivery_id": self.next_outgoing_id, } + if network_trace_params: + # We determine the logging for the outgoing Transfer frames based on the source + # Link configuration rather than the Session, because it's only at the Session + # level that we can determine how many outgoing frames are needed and their + # delivery IDs. + # TODO: Obscuring the payload for now to investigate the potential for leaks. 
+ _LOGGER.debug( + "-> %r", TransferFrame(payload=b"***", **tmp_delivery_frame), + extra=network_trace_params + ) await self._connection._process_outgoing_frame( # pylint: disable=protected-access - self.channel, TransferFrame(**tmp_delivery_frame) + self.channel, + TransferFrame(payload=payload[start_idx:], **tmp_delivery_frame) ) self.next_outgoing_id += 1 self.remote_incoming_window -= 1 @@ -304,6 +336,10 @@ async def _incoming_transfer(self, frame): try: await self._input_handles[frame[0]]._incoming_transfer(frame) # pylint: disable=protected-access except KeyError: + _LOGGER.error( + "Received Transfer frame on unattached link. Ending session.", + extra=self.network_trace_params + ) await self._set_state(SessionState.DISCARDING) await self.end( error=AMQPError( @@ -321,7 +357,7 @@ async def _outgoing_disposition(self, frame): async def _incoming_disposition(self, frame): if self.network_trace: - _LOGGER.info("<- %r", DispositionFrame(*frame), extra=self.network_trace_params) + _LOGGER.debug("<- %r", DispositionFrame(*frame), extra=self.network_trace_params) for link in self._input_handles.values(): await link._incoming_disposition(frame) # pylint: disable=protected-access @@ -381,7 +417,7 @@ async def end(self, error=None, wait=False): await self._set_state(new_state) await self._wait_for_response(wait, SessionState.UNMAPPED) except Exception as exc: # pylint: disable=broad-except - _LOGGER.info("An error occurred when ending the session: %r", exc) + _LOGGER.info("An error occurred when ending the session: %r", exc, extra=self.network_trace_params) await self._set_state(SessionState.UNMAPPED) def create_receiver_link(self, source_address, **kwargs): @@ -417,5 +453,6 @@ def create_request_response_link_pair(self, endpoint, **kwargs): self, endpoint, network_trace=kwargs.pop("network_trace", self.network_trace), + network_trace_params=dict(self.network_trace_params), **kwargs, ) diff --git 
a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index 42bacb51ad7f..a5caf7418fc1 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -76,7 +76,6 @@ async def receive_frame(self, timeout=None, **kwargs): decoded = decode_empty_frame(header) else: decoded = decode_frame(payload) - _LOGGER.info("ICH%d <- %r", channel, decoded) return channel, decoded except ( TimeoutError, @@ -104,7 +103,13 @@ async def read(self, verify_frame_type=0): frame_type = frame_header[5] if verify_frame_type is not None and frame_type != verify_frame_type: _LOGGER.debug( - "Received invalid frame type: %r, expected: %r", frame_type, verify_frame_type + "Received invalid frame type: %r, expected: %r", + frame_type, + verify_frame_type, + extra=self.network_trace_params + ) + raise ValueError( + f"Received invalid frame type: {frame_type}, expected: {verify_frame_type}" ) # >I is an unsigned int, but the argument to sock.recv is signed, # so we know the size can be at most 2 * SIGNED_INT_MAX @@ -123,9 +128,13 @@ async def read(self, verify_frame_type=0): read_frame_buffer.write( await self._read(payload_size, buffer=payload) ) - except asyncio.CancelledError: # pylint: disable=try-except-raise - raise - except (TimeoutError, socket.timeout, asyncio.IncompleteReadError): + except ( + asyncio.CancelledError, + asyncio.TimeoutError, + TimeoutError, + socket.timeout, + asyncio.IncompleteReadError + ): read_frame_buffer.write(self._read_buffer.getvalue()) self._read_buffer = read_frame_buffer self._read_buffer.seek(0) @@ -137,10 +146,23 @@ async def read(self, verify_frame_type=0): raise socket.timeout() if get_errno(exc) not in _UNAVAIL: self.connected = False + _LOGGER.debug("Transport read failed: %r", exc, extra=self.network_trace_params) raise offset -= 2 return 
frame_header, channel, payload[offset:] + async def write(self, s): + async with self.socket_lock: + try: + await self._write(s) + except socket.timeout: + raise + except (OSError, IOError, socket.error) as exc: + _LOGGER.debug("Transport write failed: %r", exc, extra=self.network_trace_params) + if get_errno(exc) not in _UNAVAIL: + self.connected = False + raise + async def send_frame(self, channel, frame, **kwargs): header, performative = encode_frame(frame, **kwargs) if performative is None: @@ -150,7 +172,6 @@ async def send_frame(self, channel, frame, **kwargs): data = header + encoded_channel + performative await self.write(data) - # _LOGGER.info("OCH%d -> %r", channel, frame) def _build_ssl_opts(self, sslopts): if sslopts in [True, False, None, {}]: @@ -228,6 +249,7 @@ def __init__( self.socket_settings = socket_settings self.socket_lock = asyncio.Lock() self.sslopts = ssl_opts + self.network_trace_params = kwargs.get('network_trace_params') async def connect(self): try: @@ -245,7 +267,8 @@ async def connect(self): # EINTR, EAGAIN, EWOULDBLOCK would signal that the banner # has _not_ been sent self.connected = True - except (OSError, IOError, SSLError): + except (OSError, IOError, SSLError) as e: + _LOGGER.info("Transport connect failed: %r", e, extra=self.network_trace_params) # if not fully connected, close socket, and reraise error if self.sock and not self.connected: self.sock.close() @@ -369,6 +392,10 @@ async def _read( toread ) nbytes = toread + except AttributeError: + # This means that close() was called concurrently + # self.reader has been set to None. 
+ raise IOError("Connection has already been closed") except asyncio.IncompleteReadError as exc: pbytes = len(exc.partial) view[nbytes : nbytes + pbytes] = exc.partial @@ -397,44 +424,30 @@ async def _read( async def _write(self, s): """Write a string out to the SSL socket fully.""" - self.writer.write(s) - await self.writer.drain() + try: + self.writer.write(s) + await self.writer.drain() + except AttributeError: + raise IOError("Connection has already been closed") async def close(self): - if self.writer is not None: - # Closing the writer closes the underlying socket. - self.writer.close() - if self.sslopts: - # see issue: https://github.com/encode/httpx/issues/914 - await asyncio.sleep(0) - self.writer.transport.abort() - await self.writer.wait_closed() + async with self.socket_lock: + try: + if self.writer is not None: + # Closing the writer closes the underlying socket. + self.writer.close() + if self.sslopts: + # see issue: https://github.com/encode/httpx/issues/914 + await asyncio.sleep(0) + self.writer.transport.abort() + await self.writer.wait_closed() + except Exception as e: # pylint: disable=broad-except + # Sometimes SSL raises APPLICATION_DATA_AFTER_CLOSE_NOTIFY here on close. 
+ _LOGGER.debug("Error shutting down socket: %r", e, extra=self.network_trace_params) self.writer, self.reader = None, None self.sock = None self.connected = False - async def write(self, s): - try: - await self._write(s) - except socket.timeout: - raise - except (OSError, IOError, socket.error) as exc: - if get_errno(exc) not in _UNAVAIL: - self.connected = False - raise - - async def receive_frame_with_lock(self, **kwargs): - try: - async with self.socket_lock: - header, channel, payload = await self.read(**kwargs) - if not payload: - decoded = decode_empty_frame(header) - else: - decoded = decode_frame(payload) - return channel, decoded - except (socket.timeout, TimeoutError): - return None, None - async def negotiate(self): if not self.sslopts: return @@ -469,6 +482,7 @@ def __init__( self.session = None self._http_proxy = kwargs.get("http_proxy", None) self.connected = False + self.network_trace_params = kwargs.get('network_trace_params') async def connect(self): self.sslopts = self._build_ssl_opts(self.sslopts) @@ -521,63 +535,55 @@ async def connect(self): heartbeat=DEFAULT_WEBSOCKET_HEARTBEAT_SECONDS, ) except ClientConnectorError as exc: + _LOGGER.info("Websocket connect failed: %r", exc, extra=self.network_trace_params) if self._custom_endpoint: raise AuthenticationException( ErrorCondition.ClientError, description="Failed to authenticate the connection due to exception: " + str(exc), error=exc, ) + raise ConnectionError("Failed to establish websocket connection: " + str(exc)) self.connected = True except ImportError: raise ValueError( "Please install aiohttp library to use websocket transport." 
) - except OSError as e: - await self.session.close() - raise ConnectionError('Websocket connection closed: %r' % e) from e - async def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-argument + async def _read(self, toread, buffer=None, **kwargs): # pylint: disable=unused-argument """Read exactly n bytes from the peer.""" - length = 0 - view = buffer or memoryview(bytearray(n)) + view = buffer or memoryview(bytearray(toread)) nbytes = self._read_buffer.readinto(view) length += nbytes - n -= nbytes - + toread -= nbytes try: - while n: + while toread: data = await self.ws.receive_bytes() - if len(data) <= n: - view[length : length + len(data)] = data - n -= len(data) + read_length = len(data) + if read_length <= toread: + view[length : length + read_length] = data + toread -= read_length + length += read_length else: - view[length : length + n] = data[0:n] - self._read_buffer = BytesIO(data[n:]) - n = 0 + view[length : length + toread] = data[0:toread] + self._read_buffer = BytesIO(data[toread:]) + toread = 0 return view - except asyncio.TimeoutError as te: - raise ConnectionError('Receive timed out (%s)' % te) - except OSError as e: - await self.session.close() - raise ConnectionError('Websocket connection closed: %r' % e) from e + except: + self._read_buffer = BytesIO(view[:length]) + raise async def close(self): """Do any preliminary work in shutting down the connection.""" - await self.ws.close() - await self.session.close() - self.connected = False + async with self.socket_lock: + await self.ws.close() + await self.session.close() + self.connected = False - async def write(self, s): + async def _write(self, s): """Completely write a string (byte array) to the peer. 
ABNF, OPCODE_BINARY = 0x2 See http://tools.ietf.org/html/rfc5234 http://tools.ietf.org/html/rfc6455#section-5.2 """ - try: - await self.ws.send_bytes(s) - except asyncio.TimeoutError as te: - raise ConnectionError('Send timed out (%s)' % te) - except OSError as e: - await self.session.close() - raise ConnectionError('Websocket connection closed: %r' % e) from e + await self.ws.send_bytes(s) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/cbs.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/cbs.py index 9270346faa6a..f2eb796b587c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/cbs.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/cbs.py @@ -66,6 +66,11 @@ def __init__(self, session, auth, **kwargs): self._expires_on = None self._token = None self._refresh_window = None + self._network_trace_params = { + "amqpConnection": self._session._connection._container_id, + "amqpSession": self._session.name, + "amqpLink": None + } self._token_status_code = None self._token_status_description = None @@ -99,12 +104,13 @@ def _on_amqp_management_open_complete(self, management_open_result): _LOGGER.debug( "CSB with status: %r encounters unexpected AMQP management open complete.", self.state, + extra=self._network_trace_params ) elif self.state == CbsState.OPEN: self.state = CbsState.ERROR _LOGGER.info( - "Unexpected AMQP management open complete in OPEN, CBS error occurred on connection %r.", - self._connection._container_id, # pylint:disable=protected-access + "Unexpected AMQP management open complete in OPEN, CBS error occurred.", + extra=self._network_trace_params ) elif self.state == CbsState.OPENING: self.state = ( @@ -112,28 +118,26 @@ def _on_amqp_management_open_complete(self, management_open_result): if management_open_result == ManagementOpenResult.OK else CbsState.CLOSED ) - _LOGGER.info( - "CBS for connection %r completed opening with status: %r", - self._connection._container_id, # pylint: disable=protected-access + 
_LOGGER.debug( + "CBS completed opening with status: %r", management_open_result, - ) # pylint:disable=protected-access + extra=self._network_trace_params + ) def _on_amqp_management_error(self): if self.state == CbsState.CLOSED: - _LOGGER.info("Unexpected AMQP error in CLOSED state.") + _LOGGER.info("Unexpected AMQP error in CLOSED state.", extra=self._network_trace_params) elif self.state == CbsState.OPENING: self.state = CbsState.ERROR self._mgmt_link.close() _LOGGER.info( - "CBS for connection %r failed to open with status: %r", - self._connection._container_id, + "CBS failed to open with status: %r", ManagementOpenResult.ERROR, - ) # pylint:disable=protected-access + extra=self._network_trace_params + ) elif self.state == CbsState.OPEN: self.state = CbsState.ERROR - _LOGGER.info( - "CBS error occurred on connection %r.", self._connection._container_id - ) # pylint:disable=protected-access + _LOGGER.info("CBS error occurred.", extra=self._network_trace_params) def _on_execute_operation_complete( self, @@ -144,14 +148,19 @@ def _on_execute_operation_complete( error_condition=None, ): if error_condition: - _LOGGER.info("CBS Put token error: %r", error_condition) + _LOGGER.info( + "CBS Put token error: %r", + error_condition, + extra=self._network_trace_params + ) self.auth_state = CbsAuthState.ERROR return - _LOGGER.info( + _LOGGER.debug( "CBS Put token result (%r), status code: %s, status_description: %s.", execute_operation_result, status_code, status_description, + extra=self._network_trace_params ) self._token_status_code = status_code self._token_status_description = status_description @@ -174,14 +183,15 @@ def _update_status(self): self.auth_state == CbsAuthState.OK or self.auth_state == CbsAuthState.REFRESH_REQUIRED ): - _LOGGER.debug("update_status In refresh required or OK.") is_expired, is_refresh_required = check_expiration_and_refresh_status( self._expires_on, self._refresh_window ) _LOGGER.debug( - "is expired == %r, is refresh required == %r", + 
"CBS status check: state == %r, expired == %r, refresh required == %r", + self.auth_state, is_expired, is_refresh_required, + extra=self._network_trace_params ) if is_expired: self.auth_state = CbsAuthState.EXPIRED @@ -189,8 +199,9 @@ def _update_status(self): self.auth_state = CbsAuthState.REFRESH_REQUIRED elif self.auth_state == CbsAuthState.IN_PROGRESS: _LOGGER.debug( - "In update status, in progress. token put time: %r", + "CBS update in progress. Token put time: %r", self._token_put_time, + extra=self._network_trace_params ) put_timeout = check_put_timeout_status( self._auth_timeout, self._token_put_time @@ -221,9 +232,15 @@ def update_token(self): self.auth_state = CbsAuthState.IN_PROGRESS access_token = self._auth.get_token() if not access_token: - _LOGGER.debug("Update_token received an empty token object") + _LOGGER.info( + "Token refresh function received an empty token object.", + extra=self._network_trace_params + ) elif not access_token.token: - _LOGGER.debug("Update_token received an empty token") + _LOGGER.info( + "Token refresh function received an empty token.", + extra=self._network_trace_params + ) self._expires_on = access_token.expires_on expires_in = self._expires_on - int(utc_now().timestamp()) self._refresh_window = int(float(expires_in) * 0.1) @@ -252,9 +269,9 @@ def handle_token(self): return True if self.auth_state == CbsAuthState.REFRESH_REQUIRED: _LOGGER.info( - "Token on connection %r will expire soon - attempting to refresh.", - self._connection._container_id, - ) # pylint:disable=protected-access + "Token will expire soon - attempting to refresh.", + extra=self._network_trace_params + ) self.update_token() return False if self.auth_state == CbsAuthState.FAILURE: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py index 8a17b202ef4e..2f3bf5c668d0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py +++ 
b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py @@ -176,6 +176,7 @@ def __init__(self, hostname, **kwargs): "remote_idle_timeout_empty_frame_send_ratio", None ) self._network_trace = kwargs.pop("network_trace", False) + self._network_trace_params = {"amqpConnection": None, "amqpSession": None, "amqpLink": None} # Session settings self._outgoing_window = kwargs.pop("outgoing_window", OUTGOING_WINDOW) @@ -280,7 +281,6 @@ def open(self, connection=None): # pylint: disable=protected-access if self._session: return # already open. - _logger.debug("Opening client connection.") if connection: self._connection = connection self._external_connection = True @@ -311,6 +311,8 @@ def open(self, connection=None): session=self._session, auth=self._auth, auth_timeout=self._auth_timeout ) self._cbs_authenticator.open() + self._network_trace_params["amqpConnection"] = self._connection._container_id + self._network_trace_params["amqpSession"] = self._session.name self._shutdown = False def close(self): @@ -337,6 +339,8 @@ def close(self): if not self._external_connection: self._connection.close() self._connection = None + self._network_trace_params["amqpConnection"] = None + self._network_trace_params["amqpSession"] = None def auth_complete(self): """Whether the authentication handshake is complete during @@ -555,13 +559,8 @@ def _client_run(self, **kwargs): :rtype: bool """ - try: - self._link.update_pending_deliveries() - self._connection.listen(wait=self._socket_timeout, **kwargs) - except ValueError: - _logger.info("Timeout reached, closing sender.") - self._shutdown = True - return False + self._link.update_pending_deliveries() + self._connection.listen(wait=self._socket_timeout, **kwargs) return True def _transfer_message(self, message_delivery, timeout=0): @@ -631,6 +630,12 @@ def _send_message_impl(self, message, **kwargs): running = True while running and message_delivery.state not in MESSAGE_DELIVERY_DONE_STATES: running = self.do_work() + if 
message_delivery.state not in MESSAGE_DELIVERY_DONE_STATES: + raise MessageException( + condition=ErrorCondition.ClientError, + description="Send failed - connection not running." + ) + if message_delivery.state in ( MessageDeliveryState.Error, MessageDeliveryState.Cancelled, @@ -797,7 +802,7 @@ def _client_run(self, **kwargs): self._link.flow() self._connection.listen(wait=self._socket_timeout, **kwargs) except ValueError: - _logger.info("Timeout reached, closing receiver.") + _logger.info("Timeout reached, closing receiver.", extra=self._network_trace_params) self._shutdown = True return False return True @@ -827,7 +832,7 @@ def _receive_message_batch_impl( self.open() while len(batch) < max_batch_size: try: - # TODO: This looses the transfer frame data + # TODO: This drops the transfer frame data _, message = self._received_messages.get_nowait() batch.append(message) self._received_messages.task_done() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/link.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/link.py index 54a81e8fc989..ab3523566cb3 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/link.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/link.py @@ -81,7 +81,7 @@ def __init__(self, session, handle, name, role, **kwargs): self.network_trace = kwargs["network_trace"] self.network_trace_params = kwargs["network_trace_params"] - self.network_trace_params["link"] = self.name + self.network_trace_params["amqpLink"] = self.name self._session = session self._is_closed = False self._on_link_state_change = kwargs.get("on_link_state_change") @@ -157,16 +157,16 @@ def _outgoing_attach(self): properties=self.properties, ) if self.network_trace: - _LOGGER.info("-> %r", attach_frame, extra=self.network_trace_params) + _LOGGER.debug("-> %r", attach_frame, extra=self.network_trace_params) self._session._outgoing_attach(attach_frame) # pylint: disable=protected-access def _incoming_attach(self, frame): if self.network_trace: 
- _LOGGER.info("<- %r", AttachFrame(*frame), extra=self.network_trace_params) + _LOGGER.debug("<- %r", AttachFrame(*frame), extra=self.network_trace_params) if self._is_closed: raise ValueError("Invalid link") if not frame[5] or not frame[6]: - _LOGGER.info("Cannot get source or target. Detaching link") + _LOGGER.info("Cannot get source or target. Detaching link", extra=self.network_trace_params) self._set_state(LinkState.DETACHED) raise ValueError("Invalid link") self.remote_handle = frame[1] # handle @@ -188,7 +188,7 @@ def _incoming_attach(self, frame): frame[6] = Target(*frame[6]) self._on_attach(AttachFrame(*frame)) except Exception as e: # pylint: disable=broad-except - _LOGGER.warning("Callback for link attach raised error: %r", e) + _LOGGER.warning("Callback for link attach raised error: %r", e, extra=self.network_trace_params) def _outgoing_flow(self, **kwargs): flow_frame = { @@ -211,14 +211,14 @@ def _incoming_disposition(self, frame): def _outgoing_detach(self, close=False, error=None): detach_frame = DetachFrame(handle=self.handle, closed=close, error=error) if self.network_trace: - _LOGGER.info("-> %r", detach_frame, extra=self.network_trace_params) + _LOGGER.debug("-> %r", detach_frame, extra=self.network_trace_params) self._session._outgoing_detach(detach_frame) # pylint: disable=protected-access if close: self._is_closed = True def _incoming_detach(self, frame): if self.network_trace: - _LOGGER.info("<- %r", DetachFrame(*frame), extra=self.network_trace_params) + _LOGGER.debug("<- %r", DetachFrame(*frame), extra=self.network_trace_params) if self.state == LinkState.ATTACHED: self._outgoing_detach(close=frame[1]) # closed elif frame[1] and not self._is_closed and self.state in [LinkState.ATTACH_SENT, LinkState.ATTACH_RCVD]: @@ -253,7 +253,7 @@ def detach(self, close=False, error=None): self._outgoing_detach(close=close, error=error) self._set_state(LinkState.DETACH_SENT) except Exception as exc: # pylint: disable=broad-except - _LOGGER.info("An 
error occurred when detaching the link: %r", exc) + _LOGGER.info("An error occurred when detaching the link: %r", exc, extra=self.network_trace_params) self._set_state(LinkState.DETACHED) def flow(self, *, link_credit: Optional[int] = None, **kwargs) -> None: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/management_link.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/management_link.py index 87290435af9b..c5b1e6c0aa19 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/management_link.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/management_link.py @@ -39,12 +39,14 @@ def __init__(self, session, endpoint, **kwargs): self.state = ManagementLinkState.IDLE self._pending_operations = [] self._session = session + self._network_trace_params = kwargs.get('network_trace_params') self._request_link: SenderLink = session.create_sender_link( endpoint, source_address=endpoint, on_link_state_change=self._on_sender_state_change, send_settle_mode=SenderSettleMode.Unsettled, - rcv_settle_mode=ReceiverSettleMode.First + rcv_settle_mode=ReceiverSettleMode.First, + network_trace=kwargs.get("network_trace", False) ) self._response_link: ReceiverLink = session.create_receiver_link( endpoint, @@ -52,7 +54,8 @@ def __init__(self, session, endpoint, **kwargs): on_link_state_change=self._on_receiver_state_change, on_transfer=self._on_message_received, send_settle_mode=SenderSettleMode.Unsettled, - rcv_settle_mode=ReceiverSettleMode.First + rcv_settle_mode=ReceiverSettleMode.First, + network_trace=kwargs.get("network_trace", False) ) self._on_amqp_management_error = kwargs.get('on_amqp_management_error') self._on_amqp_management_open_complete = kwargs.get('on_amqp_management_open_complete') @@ -71,7 +74,12 @@ def __exit__(self, *args): self.close() def _on_sender_state_change(self, previous_state, new_state): - _LOGGER.info("Management link sender state changed: %r -> %r", previous_state, new_state) + _LOGGER.info( + "Management link 
sender state changed: %r -> %r", + previous_state, + new_state, + extra=self._network_trace_params + ) if new_state == previous_state: return if self.state == ManagementLinkState.OPENING: @@ -96,7 +104,12 @@ def _on_sender_state_change(self, previous_state, new_state): return def _on_receiver_state_change(self, previous_state, new_state): - _LOGGER.info("Management link receiver state changed: %r -> %r", previous_state, new_state) + _LOGGER.info( + "Management link receiver state changed: %r -> %r", + previous_state, + new_state, + extra=self._network_trace_params + ) if new_state == previous_state: return if self.state == ManagementLinkState.OPENING: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/management_operation.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/management_operation.py index d9e9080ea260..475c3424a897 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/management_operation.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/management_operation.py @@ -28,6 +28,11 @@ def __init__(self, session, endpoint='$management', **kwargs): self._session = session self._connection = self._session._connection + self._network_trace_params = { + "amqpConnection": self._session._connection._container_id, + "amqpSession": self._session.name, + "amqpLink": None + } self._mgmt_link = self._session.create_request_response_link_pair( endpoint=endpoint, on_amqp_management_open_complete=self._on_amqp_management_open_complete, @@ -61,22 +66,22 @@ def _on_execute_operation_complete( error=None ): _LOGGER.debug( - "mgmt operation completed, operation id: %r; operation_result: %r; status_code: %r; " - "status_description: %r, raw_message: %r, error: %r", + "Management operation completed, id: %r; result: %r; code: %r; description: %r, error: %r", operation_id, operation_result, status_code, status_description, - raw_message, - error + error, + extra=self._network_trace_params ) if operation_result in\ 
(ManagementExecuteOperationResult.ERROR, ManagementExecuteOperationResult.LINK_CLOSED): self._mgmt_error = error _LOGGER.error( - "Failed to complete mgmt operation due to error: %r. The management request message is: %r", - error, raw_message + "Failed to complete management operation due to error: %r.", + error, + extra=self._network_trace_params ) else: self._responses[operation_id] = (status_code, status_description, raw_message) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/receiver.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/receiver.py index 2e3773243f85..5713f51b4b8c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/receiver.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/receiver.py @@ -38,13 +38,13 @@ def _process_incoming_message(self, frame, message): try: return self._on_transfer(frame, message) except Exception as e: # pylint: disable=broad-except - _LOGGER.error("Handler function failed with error: %r", e) + _LOGGER.error("Transfer callback function failed with error: %r", e, extra=self.network_trace_params) return None def _incoming_attach(self, frame): super(ReceiverLink, self)._incoming_attach(frame) if frame[9] is None: # initial_delivery_count - _LOGGER.info("Cannot get initial-delivery-count. Detaching link") + _LOGGER.info("Cannot get initial-delivery-count. Detaching link", extra=self.network_trace_params) self._set_state(LinkState.DETACHED) # TODO: Send detach now? 
self.delivery_count = frame[9] self.current_link_credit = self.link_credit @@ -52,7 +52,7 @@ def _incoming_attach(self, frame): def _incoming_transfer(self, frame): if self.network_trace: - _LOGGER.debug("<- %r", TransferFrame(*frame), extra=self.network_trace_params) + _LOGGER.debug("<- %r", TransferFrame(payload=b"***", *frame[:-1]), extra=self.network_trace_params) self.current_link_credit -= 1 self.delivery_count += 1 self.received_delivery_id = frame[1] # delivery_id @@ -66,8 +66,6 @@ def _incoming_transfer(self, frame): self._received_payload = bytearray() else: message = decode_payload(frame[11]) - if self.network_trace: - _LOGGER.debug(" %r", message, extra=self.network_trace_params) delivery_state = self._process_incoming_message(frame, message) if not frame[4] and delivery_state: # settled self._outgoing_disposition( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sender.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sender.py index 9ee708e9b1d6..26c78f5f9c17 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sender.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sender.py @@ -11,9 +11,6 @@ from ._encode import encode_payload from .link import Link from .constants import SessionTransferState, LinkDeliverySettleReason, LinkState, Role, SenderSettleMode, SessionState -from .performatives import ( - TransferFrame, -) from .error import AMQPLinkError, ErrorCondition, MessageException _LOGGER = logging.getLogger(__name__) @@ -29,13 +26,18 @@ def __init__(self, **kwargs): self.transfer_state = None self.timeout = kwargs.get("timeout") self.settled = kwargs.get("settled", False) + self._network_trace_params = kwargs.get('network_trace_params') def on_settled(self, reason, state): if self.on_delivery_settled and not self.settled: try: self.on_delivery_settled(reason, state) except Exception as e: # pylint:disable=broad-except - _LOGGER.warning("Message 'on_send_complete' callback failed: %r", e) + _LOGGER.warning( 
+ "Message 'on_send_complete' callback failed: %r", + e, + extra=self._network_trace_params + ) self.settled = True @@ -75,7 +77,10 @@ def _incoming_flow(self, frame): rcv_delivery_count = frame[5] # delivery_count if frame[4] is not None: # handle if rcv_link_credit is None or rcv_delivery_count is None: - _LOGGER.info("Unable to get link-credit or delivery-count from incoming ATTACH. Detaching link.") + _LOGGER.info( + "Unable to get link-credit or delivery-count from incoming ATTACH. Detaching link.", + extra=self.network_trace_params + ) self._remove_pending_deliveries() self._set_state(LinkState.DETACHED) # TODO: Send detach now? else: @@ -99,12 +104,10 @@ def _outgoing_transfer(self, delivery): "batchable": None, "payload": output, } - if self.network_trace: - _LOGGER.debug( - "-> %r", TransferFrame(delivery_id="", **delivery.frame), extra=self.network_trace_params - ) - _LOGGER.debug(" %r", delivery.message, extra=self.network_trace_params) - self._session._outgoing_transfer(delivery) # pylint:disable=protected-access + self._session._outgoing_transfer( # pylint:disable=protected-access + delivery, + self.network_trace_params if self.network_trace else None + ) sent_and_settled = False if delivery.transfer_state == SessionTransferState.OKAY: self.delivery_count = delivery_count @@ -172,6 +175,7 @@ def send_transfer(self, message, *, send_async=False, **kwargs): timeout=kwargs.get("timeout"), message=message, settled=settled, + network_trace_params = self.network_trace_params ) if self.current_link_credit == 0 or send_async: self._pending_deliveries.append(delivery) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/session.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/session.py index ea36c5b1be1c..3582b2e64e48 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/session.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/session.py @@ -60,7 +60,7 @@ def __init__(self, connection, channel, **kwargs): 
self.idle_wait_time = kwargs.get("idle_wait_time", 0.1) self.network_trace = kwargs["network_trace"] self.network_trace_params = kwargs["network_trace_params"] - self.network_trace_params["session"] = self.name + self.network_trace_params["amqpSession"] = self.name self.links = {} self._connection = connection @@ -138,14 +138,14 @@ def _outgoing_begin(self): properties=self.properties, ) if self.network_trace: - _LOGGER.info("-> %r", begin_frame, extra=self.network_trace_params) + _LOGGER.debug("-> %r", begin_frame, extra=self.network_trace_params) self._connection._process_outgoing_frame( # pylint: disable=protected-access self.channel, begin_frame ) def _incoming_begin(self, frame): if self.network_trace: - _LOGGER.info("<- %r", BeginFrame(*frame), extra=self.network_trace_params) + _LOGGER.debug("<- %r", BeginFrame(*frame), extra=self.network_trace_params) self.handle_max = frame[4] # handle_max self.next_incoming_id = frame[1] # next_outgoing_id self.remote_incoming_window = frame[2] # incoming_window @@ -161,14 +161,14 @@ def _incoming_begin(self, frame): def _outgoing_end(self, error=None): end_frame = EndFrame(error=error) if self.network_trace: - _LOGGER.info("-> %r", end_frame, extra=self.network_trace_params) + _LOGGER.debug("-> %r", end_frame, extra=self.network_trace_params) self._connection._process_outgoing_frame( # pylint: disable=protected-access self.channel, end_frame ) def _incoming_end(self, frame): if self.network_trace: - _LOGGER.info("<- %r", EndFrame(*frame), extra=self.network_trace_params) + _LOGGER.debug("<- %r", EndFrame(*frame), extra=self.network_trace_params) if self.state not in [ SessionState.END_RCVD, SessionState.END_SENT, @@ -198,6 +198,10 @@ def _incoming_attach(self, frame): try: outgoing_handle = self._get_next_output_handle() except ValueError: + _LOGGER.error( + "Unable to attach new link - cannot allocate more handles.", + extra=self.network_trace_params + ) # detach the link that would have been set. 
self.links[frame[0].decode("utf-8")].detach( error=AMQPError( @@ -220,8 +224,13 @@ def _incoming_attach(self, frame): self.links[frame[0]] = new_link self._output_handles[outgoing_handle] = new_link self._input_handles[frame[1]] = new_link - except ValueError: + except ValueError as e: # Reject Link + _LOGGER.error( + "Unable to attach new link: %r", + e, + extra=self.network_trace_params + ) self._input_handles[frame[1]].detach() def _outgoing_flow(self, frame=None): @@ -236,14 +245,14 @@ def _outgoing_flow(self, frame=None): ) flow_frame = FlowFrame(**link_flow) if self.network_trace: - _LOGGER.info("-> %r", flow_frame, extra=self.network_trace_params) + _LOGGER.debug("-> %r", flow_frame, extra=self.network_trace_params) self._connection._process_outgoing_frame( # pylint: disable=protected-access self.channel, flow_frame ) def _incoming_flow(self, frame): if self.network_trace: - _LOGGER.info("<- %r", FlowFrame(*frame), extra=self.network_trace_params) + _LOGGER.debug("<- %r", FlowFrame(*frame), extra=self.network_trace_params) self.next_incoming_id = frame[2] # next_outgoing_id remote_incoming_id = ( frame[0] or self.next_outgoing_id @@ -263,7 +272,7 @@ def _incoming_flow(self, frame): ): link._incoming_flow(frame) # pylint: disable=protected-access - def _outgoing_transfer(self, delivery): + def _outgoing_transfer(self, delivery, network_trace_params): if self.state != SessionState.MAPPED: delivery.transfer_state = SessionTransferState.ERROR if self.remote_incoming_window <= 0: @@ -300,11 +309,24 @@ def _outgoing_transfer(self, delivery): "resume": delivery.frame["resume"], "aborted": delivery.frame["aborted"], "batchable": delivery.frame["batchable"], - "payload": payload[start_idx : start_idx + available_frame_size], "delivery_id": self.next_outgoing_id, } + if network_trace_params: + # We determine the logging for the outgoing Transfer frames based on the source + # Link configuration rather than the Session, because it's only at the Session + # level that 
we can determine how many outgoing frames are needed and their + # delivery IDs. + # TODO: Obscuring the payload for now to investigate the potential for leaks. + _LOGGER.debug( + "-> %r", TransferFrame(payload=b"***", **tmp_delivery_frame), + extra=network_trace_params + ) self._connection._process_outgoing_frame( # pylint: disable=protected-access - self.channel, TransferFrame(**tmp_delivery_frame) + self.channel, + TransferFrame( + payload=payload[start_idx : start_idx + available_frame_size], + **tmp_delivery_frame + ) ) start_idx += available_frame_size remaining_payload_cnt -= available_frame_size @@ -321,11 +343,21 @@ def _outgoing_transfer(self, delivery): "resume": delivery.frame["resume"], "aborted": delivery.frame["aborted"], "batchable": delivery.frame["batchable"], - "payload": payload[start_idx:], "delivery_id": self.next_outgoing_id, } + if network_trace_params: + # We determine the logging for the outgoing Transfer frames based on the source + # Link configuration rather than the Session, because it's only at the Session + # level that we can determine how many outgoing frames are needed and their + # delivery IDs. + # TODO: Obscuring the payload for now to investigate the potential for leaks. + _LOGGER.debug( + "-> %r", TransferFrame(payload=b"***", **tmp_delivery_frame), + extra=network_trace_params + ) self._connection._process_outgoing_frame( # pylint: disable=protected-access - self.channel, TransferFrame(**tmp_delivery_frame) + self.channel, + TransferFrame(payload=payload[start_idx:], **tmp_delivery_frame) ) self.next_outgoing_id += 1 self.remote_incoming_window -= 1 @@ -342,6 +374,10 @@ def _incoming_transfer(self, frame): frame ) except KeyError: + _LOGGER.error( + "Received Transfer frame on unattached link. 
Ending session.", + extra=self.network_trace_params + ) self._set_state(SessionState.DISCARDING) self.end( error=AMQPError( @@ -350,6 +386,7 @@ def _incoming_transfer(self, frame): """Handle is not currently associated with an attached link""", ) ) + return if self.incoming_window == 0: self.incoming_window = self.target_incoming_window self._outgoing_flow() @@ -361,7 +398,7 @@ def _outgoing_disposition(self, frame): def _incoming_disposition(self, frame): if self.network_trace: - _LOGGER.info( + _LOGGER.debug( "<- %r", DispositionFrame(*frame), extra=self.network_trace_params ) for link in self._input_handles.values(): @@ -427,7 +464,7 @@ def end(self, error=None, wait=False): self._set_state(new_state) self._wait_for_response(wait, SessionState.UNMAPPED) except Exception as exc: # pylint: disable=broad-except - _LOGGER.info("An error occurred when ending the session: %r", exc) + _LOGGER.info("An error occurred when ending the session: %r", exc, extra=self.network_trace_params) self._set_state(SessionState.UNMAPPED) def create_receiver_link(self, source_address, **kwargs): @@ -463,5 +500,6 @@ def create_request_response_link_pair(self, endpoint, **kwargs): self, endpoint, network_trace=kwargs.pop("network_trace", self.network_trace), + network_trace_params=dict(self.network_trace_params), **kwargs, ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_eventprocessor/event_processor.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_eventprocessor/event_processor.py index 3b666f697329..5a3408912afa 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_eventprocessor/event_processor.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_eventprocessor/event_processor.py @@ -251,7 +251,6 @@ async def _close_consumer(self, partition_context): await self._ownership_manager.release_ownership(partition_id) finally: if partition_id in self._tasks: - self._tasks[partition_id].cancel() del self._tasks[partition_id] async def _receive( diff --git 
a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py index c2aac86871c9..8a2432f543ea 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py @@ -173,33 +173,38 @@ async def _callback_task(consumer, batch, max_batch_size, max_wait_time): @staticmethod async def _receive_task(consumer): - max_retries = consumer._client._config.max_retries # pylint:disable=protected-access + # pylint:disable=protected-access + max_retries = consumer._client._config.max_retries retried_times = 0 - while retried_times <= max_retries and consumer._callback_task_run: # pylint: disable=protected-access - try: - await consumer._open() # pylint: disable=protected-access - await cast(ReceiveClientAsync, consumer._handler).do_work_async(batch=consumer._prefetch) # pylint: disable=protected-access - except asyncio.CancelledError: # pylint: disable=try-except-raise - raise - except Exception as exception: # pylint: disable=broad-except - if ( - isinstance(exception, errors.AMQPLinkError) - and exception.condition == errors.ErrorCondition.LinkStolen # pylint: disable=no-member - ): - raise await consumer._handle_exception(exception) # pylint: disable=protected-access - if not consumer.running: # exit by close - return - if consumer._last_received_event: # pylint: disable=protected-access - consumer._offset = consumer._last_received_event.offset # pylint: disable=protected-access - last_exception = await consumer._handle_exception(exception) # pylint: disable=protected-access - retried_times += 1 - if retried_times > max_retries: - _LOGGER.info( - "%r operation has exhausted retry. 
Last exception: %r.", - consumer._name, # pylint: disable=protected-access - last_exception, - ) - raise last_exception + running = True + try: + while retried_times <= max_retries and running and consumer._callback_task_run: + try: + await consumer._open() # pylint: disable=protected-access + running = await cast(ReceiveClientAsync, consumer._handler).do_work_async(batch=consumer._prefetch) + except asyncio.CancelledError: # pylint: disable=try-except-raise + raise + except Exception as exception: # pylint: disable=broad-except + if ( + isinstance(exception, errors.AMQPLinkError) + and exception.condition == errors.ErrorCondition.LinkStolen # pylint: disable=no-member + ): + raise await consumer._handle_exception(exception) + if not consumer.running: # exit by close + return + if consumer._last_received_event: + consumer._offset = consumer._last_received_event.offset + last_exception = await consumer._handle_exception(exception) + retried_times += 1 + if retried_times > max_retries: + _LOGGER.info( + "%r operation has exhausted retry. 
Last exception: %r.", + consumer._name, + last_exception, + ) + raise last_exception + finally: + consumer._callback_task_run = False @staticmethod async def message_received_async(consumer, message: Message) -> None: @@ -225,19 +230,12 @@ async def receive_messages_async(consumer, batch, max_batch_size, max_wait_time) tasks = [callback_task, receive_task] try: - for task in asyncio.as_completed(tasks): - try: - await task - await asyncio.sleep(0) - except Exception: # pylint: disable=broad-except - consumer._callback_task_run = False - for task in tasks: - if task.done() and task.exception(): - raise task.exception() - except asyncio.CancelledError: + await asyncio.gather(*tasks) + finally: consumer._callback_task_run = False - await asyncio.sleep(0) - raise + for t in tasks: + if not t.done(): + await asyncio.wait([t], timeout=1) @staticmethod async def create_token_auth_async(auth_uri, get_token, token_type, config, **kwargs): diff --git a/sdk/eventhub/azure-eventhub/tests/perfstress_tests/_test_base.py b/sdk/eventhub/azure-eventhub/tests/perfstress_tests/_test_base.py index dbef4b8f73c2..03b107a8fc5c 100644 --- a/sdk/eventhub/azure-eventhub/tests/perfstress_tests/_test_base.py +++ b/sdk/eventhub/azure-eventhub/tests/perfstress_tests/_test_base.py @@ -5,9 +5,10 @@ import asyncio from uuid import uuid4 +from datetime import datetime from azure_devtools.perfstress_tests import BatchPerfTest, EventPerfTest, get_random_bytes -from azure.eventhub import EventHubProducerClient, EventHubConsumerClient, EventData +from azure.eventhub import EventHubProducerClient, EventHubConsumerClient, EventData, TransportType from azure.eventhub.aio import ( EventHubProducerClient as AsyncEventHubProducerClient, EventHubConsumerClient as AsyncEventHubConsumerClient @@ -34,25 +35,45 @@ def __init__(self, arguments): self.checkpoint_store = BlobCheckpointStore.from_connection_string(storage_connection_str, self.container_name) self.async_checkpoint_store = 
AsyncBlobCheckpointStore.from_connection_string(storage_connection_str, self.container_name) + transport_type = TransportType.AmqpOverWebsocket if arguments.transport_type==1 else TransportType.Amqp + self.consumer = EventHubConsumerClient.from_connection_string( connection_string, _EventHubProcessorTest.consumer_group, eventhub_name=eventhub_name, checkpoint_store=self.checkpoint_store, - load_balancing_strategy=arguments.load_balancing_strategy + load_balancing_strategy=arguments.load_balancing_strategy, + transport_type=transport_type, + uamqp_transport=arguments.uamqp_transport ) self.async_consumer = AsyncEventHubConsumerClient.from_connection_string( connection_string, _EventHubProcessorTest.consumer_group, eventhub_name=eventhub_name, checkpoint_store=self.async_checkpoint_store, - load_balancing_strategy=arguments.load_balancing_strategy + load_balancing_strategy=arguments.load_balancing_strategy, + transport_type=transport_type, + uamqp_transport=arguments.uamqp_transport ) if arguments.preload: - self.async_producer = AsyncEventHubProducerClient.from_connection_string(connection_string, eventhub_name=eventhub_name) + self.data = get_random_bytes(self.args.event_size) + self.async_producer = AsyncEventHubProducerClient.from_connection_string(connection_string, eventhub_name=eventhub_name, transport_type=transport_type, uamqp_transport=arguments.uamqp_transport) + + def _build_event(self): + event = EventData(self.data) + if self.args.event_extra: + event.raw_amqp_message.header.first_acquirer = True + event.raw_amqp_message.properties.subject = 'perf' + event.properties = { + "key1": b"data", + "key2": 42, + "key3": datetime.now(), + "key4": "foobar", + "key5": uuid4() + } + return event async def _preload_eventhub(self): - data = get_random_bytes(self.args.event_size) async with self.async_producer as producer: partitions = await producer.get_partition_ids() total_events = 0 @@ -65,13 +86,13 @@ async def _preload_eventhub(self): batch = await 
producer.create_batch() for i in range(events_to_add): try: - batch.add(EventData(data)) + batch.add(self._build_event()) except ValueError: # Batch full await producer.send_batch(batch) print(f"Loaded {i} of {events_to_add} events.") batch = await producer.create_batch() - batch.add(EventData(data)) + batch.add(self._build_event()) await producer.send_batch(batch) print(f"Finished loading {events_to_add} events.") @@ -120,6 +141,10 @@ def add_arguments(parser): parser.add_argument('--processing-delay-strategy', nargs='?', type=str, help="Whether to 'sleep' or 'spin' during processing delay. Default is 'sleep'.", default='sleep') parser.add_argument('--preload', nargs='?', type=int, help='Ensure the specified number of events are available across all partitions. Default is 0.', default=0) parser.add_argument('--use-storage-checkpoint', action="store_true", help="Use Blob storage for checkpointing. Default is False (in-memory checkpointing).", default=False) + parser.add_argument('--uamqp-transport', action="store_true", help="Switch to use uamqp transport. Default is False (pyamqp).", default=False) + parser.add_argument('--transport-type', nargs='?', type=int, help="Use Amqp (0) or Websocket (1) transport type. Default is Amqp.", default=0) + parser.add_argument('--event-extra', action="store_true", help="Add properties to the events to increase payload and serialization. 
Default is False.", default=False) + class _SendTest(BatchPerfTest): @@ -129,13 +154,20 @@ def __init__(self, arguments): super().__init__(arguments) connection_string = self.get_from_env("AZURE_EVENTHUB_CONNECTION_STRING") eventhub_name = self.get_from_env("AZURE_EVENTHUB_NAME") + + transport_type = TransportType.AmqpOverWebsocket if arguments.transport_type==1 else TransportType.Amqp + self.producer = EventHubProducerClient.from_connection_string( connection_string, - eventhub_name=eventhub_name + eventhub_name=eventhub_name, + transport_type=transport_type, + uamqp_transport=arguments.uamqp_transport ) self.async_producer = AsyncEventHubProducerClient.from_connection_string( connection_string, - eventhub_name=eventhub_name + eventhub_name=eventhub_name, + transport_type=transport_type, + uamqp_transport=arguments.uamqp_transport ) async def setup(self): @@ -156,3 +188,6 @@ def add_arguments(parser): super(_SendTest, _SendTest).add_arguments(parser) parser.add_argument('--event-size', nargs='?', type=int, help='Size of event body (in bytes). Defaults to 100 bytes', default=100) parser.add_argument('--batch-size', nargs='?', type=int, help='The number of events that should be included in each batch. Defaults to 100', default=100) + parser.add_argument('--uamqp-transport', action="store_true", help="Switch to use uamqp transport. Default is False (pyamqp).", default=False) + parser.add_argument('--transport-type', nargs='?', type=int, help="Use Amqp (0) or Websocket (1) transport type. Default is Amqp.", default=0) + parser.add_argument('--event-extra', action="store_true", help="Add properties to the events to increase payload and serialization. 
Default is False.", default=False) diff --git a/sdk/eventhub/azure-eventhub/tests/perfstress_tests/process_events_batch.py b/sdk/eventhub/azure-eventhub/tests/perfstress_tests/process_events_batch.py index 224652255a8e..15662ba70aed 100644 --- a/sdk/eventhub/azure-eventhub/tests/perfstress_tests/process_events_batch.py +++ b/sdk/eventhub/azure-eventhub/tests/perfstress_tests/process_events_batch.py @@ -28,7 +28,7 @@ def process_events_sync(self, partition_context, events): pass # Consume properties and body. - _ = [(e.properties, e.body) for e in events] + _ = [(list(e.body), str(e)) for e in events] if self.args.checkpoint_interval: self._partition_event_count[partition_context.partition_id] += len(events) @@ -51,9 +51,9 @@ async def process_events_async(self, partition_context, events): starttime = time.time() while (time.time() - starttime) < delay_in_seconds: pass - + # Consume properties and body. - _ = [(e.properties, e.body) for e in events] + _ = [(list(e.body), str(e)) for e in events] if self.args.checkpoint_interval: self._partition_event_count[partition_context.partition_id] += len(events) @@ -66,9 +66,11 @@ async def process_events_async(self, partition_context, events): await self.error_raised_async(e) def process_error_sync(self, _, error): + print(error) self.error_raised_sync(error) async def process_error_async(self, _, error): + print(error) await self.error_raised_async(error) def start_events_sync(self) -> None: @@ -100,4 +102,4 @@ async def start_events_async(self) -> None: @staticmethod def add_arguments(parser): super(ProcessEventsBatchTest, ProcessEventsBatchTest).add_arguments(parser) - parser.add_argument('--max-batch-size', nargs='?', type=int, help='Maximum number of events to process in a single batch. Defaults to 100.', default=100) + parser.add_argument('--max-batch-size', nargs='?', type=int, help='Maximum number of events to process in a single batch. 
Defaults to 100.', default=100) \ No newline at end of file diff --git a/sdk/eventhub/azure-eventhub/tests/perfstress_tests/send_event_batch.py b/sdk/eventhub/azure-eventhub/tests/perfstress_tests/send_event_batch.py index 538f2cc271d9..2a2bf0a506f7 100644 --- a/sdk/eventhub/azure-eventhub/tests/perfstress_tests/send_event_batch.py +++ b/sdk/eventhub/azure-eventhub/tests/perfstress_tests/send_event_batch.py @@ -3,6 +3,9 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from datetime import datetime +from uuid import uuid4 + from ._test_base import _SendTest from azure_devtools.perfstress_tests import get_random_bytes @@ -15,16 +18,30 @@ def __init__(self, arguments): super().__init__(arguments) self.data = get_random_bytes(self.args.event_size) + def _build_event(self): + event = EventData(self.data) + if self.args.event_extra: + event.raw_amqp_message.header.first_acquirer = True + event.raw_amqp_message.properties.subject = 'perf' + event.properties = { + "key1": b"data", + "key2": 42, + "key3": datetime.now(), + "key4": "foobar", + "key5": uuid4() + } + return event + def run_batch_sync(self): batch = self.producer.create_batch() for _ in range(self.args.batch_size): - batch.add(EventData(self.data)) + batch.add(self._build_event()) self.producer.send_batch(batch) return self.args.batch_size async def run_batch_async(self): batch = await self.async_producer.create_batch() for _ in range(self.args.batch_size): - batch.add(EventData(self.data)) + batch.add(self._build_event()) await self.async_producer.send_batch(batch) return self.args.batch_size diff --git a/sdk/eventhub/azure-eventhub/tests/perfstress_tests/send_events.py b/sdk/eventhub/azure-eventhub/tests/perfstress_tests/send_events.py index 8be9b87df685..851a84be1540 100644 --- a/sdk/eventhub/azure-eventhub/tests/perfstress_tests/send_events.py +++ 
b/sdk/eventhub/azure-eventhub/tests/perfstress_tests/send_events.py @@ -2,6 +2,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from datetime import datetime +from uuid import uuid4 from ._test_base import _SendTest @@ -15,20 +17,34 @@ def __init__(self, arguments): super().__init__(arguments) self.data = get_random_bytes(self.args.event_size) + def _build_event(self): + event = EventData(self.data) + if self.args.event_extra: + event.raw_amqp_message.header.first_acquirer = True + event.raw_amqp_message.properties.subject = 'perf' + event.properties = { + "key1": b"data", + "key2": 42, + "key3": datetime.now(), + "key4": "foobar", + "key5": uuid4() + } + return event + def run_batch_sync(self): if self.args.batch_size > 1: self.producer.send_batch( - [EventData(self.data) for _ in range(self.args.batch_size)] + [self._build_event() for _ in range(self.args.batch_size)] ) else: - self.producer.send_event(EventData(self.data)) + self.producer.send_event(self._build_event()) return self.args.batch_size async def run_batch_async(self): if self.args.batch_size > 1: await self.async_producer.send_batch( - [EventData(self.data) for _ in range(self.args.batch_size)] + [self._build_event() for _ in range(self.args.batch_size)] ) else: - await self.async_producer.send_event(EventData(self.data)) + await self.async_producer.send_event(self._build_event()) return self.args.batch_size