Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EventHubs] sync with SB pyamqp #34407

Merged
merged 9 commits into from
Mar 27, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -337,13 +337,13 @@ def _get_next_outgoing_channel(self) -> int:

def _outgoing_empty(self) -> None:
"""Send an empty frame to prevent the connection from reaching an idle timeout."""
if self._network_trace:
_LOGGER.debug("-> EmptyFrame()", extra=self._network_trace_params)
if self._error:
raise self._error

try:
if self._can_write():
if self._network_trace:
_LOGGER.debug("-> EmptyFrame()", extra=self._network_trace_params)
self._transport.write(EMPTY_FRAME)
self._last_frame_sent_time = time.time()
except (OSError, IOError, SSLError, socket.error) as exc:
Expand Down Expand Up @@ -516,7 +516,7 @@ def _incoming_close(self, channel: int, frame: Tuple[Any, ...]) -> None:
self._error = AMQPConnectionError(
condition=frame[0][0], description=frame[0][1], info=frame[0][2]
)
_LOGGER.error(
_LOGGER.warning(
"Connection closed with error: %r", frame[0],
extra=self._network_trace_params
)
Expand Down Expand Up @@ -667,7 +667,10 @@ def _process_outgoing_frame(self, channel: int, frame) -> None:
ConnectionState.OPEN_SENT,
ConnectionState.OPENED,
]:
raise ValueError("Connection not open.")
raise AMQPConnectionError(
ErrorCondition.SocketError,
description="Connection not open."
)
now = time.time()
if get_local_timeout(
now,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,11 @@ def _write(self, s):
"""Write a string out to the SSL socket fully.
:param str s: The string to write.
"""
write = self.sock.send
try:
write = self.sock.send
except AttributeError:
raise IOError("Socket has already been closed.") from None

while s:
try:
n = write(s)
Expand All @@ -659,7 +663,7 @@ def _write(self, s):
# None.
n = 0
if not n:
raise IOError("Socket closed")
raise IOError("Socket closed.")
s = s[n:]

def negotiate(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ async def _on_execute_operation_complete(
async def _update_status(self):
if self.auth_state in (CbsAuthState.OK, CbsAuthState.REFRESH_REQUIRED):
is_expired, is_refresh_required = check_expiration_and_refresh_status(
self._expires_on, self._refresh_window
self._expires_on, self._refresh_window # type: ignore
)
_LOGGER.debug(
"CBS status check: state == %r, expired == %r, refresh required == %r",
Expand Down Expand Up @@ -235,13 +235,13 @@ async def update_token(self) -> None:
elif isinstance(access_token.token, str):
self._token = access_token.token
else:
raise ValueError("Token must be either bytes or string.")
raise ValueError("Token must be a string or bytes.")
if isinstance(self._auth.token_type, bytes):
token_type = self._auth.token_type.decode()
elif isinstance(self._auth.token_type, str):
token_type = self._auth.token_type
else:
raise ValueError("Token type must be either bytes or string.")
raise ValueError("Token type must be a string or bytes.")

self._token_put_time = int(utc_now().timestamp())
if self._token and token_type:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,6 @@ async def _keep_alive_async(self):
current_time = time.time()
elapsed_time = current_time - start_time
if elapsed_time >= self._keep_alive_interval:
_logger.debug(
"Keeping %r connection alive.",
self.__class__.__name__,
extra=self._network_trace_params
)
await asyncio.shield(self._connection.listen(wait=self._socket_timeout,
batch=self._link.current_link_credit))
start_time = current_time
Expand Down Expand Up @@ -723,7 +718,7 @@ async def _client_ready_async(self):
if not self._link:
self._link = self._session.create_receiver_link(
source_address=self.source,
link_credit=self._link_credit,
link_credit=0, # link_credit=0 on flow frame sent before client is ready
send_settle_mode=self._send_settle_mode,
rcv_settle_mode=self._receive_settle_mode,
max_message_size=self._max_message_size,
Expand All @@ -748,7 +743,7 @@ async def _client_run_async(self, **kwargs):
"""
try:
if self._link.current_link_credit <= 0:
await self._link.flow()
await self._link.flow(link_credit=self._link_credit)
await self._connection.listen(wait=self._socket_timeout, **kwargs)
except ValueError:
_logger.info("Timeout reached, closing receiver.", extra=self._network_trace_params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,13 @@ def _get_next_outgoing_channel(self) -> int:

async def _outgoing_empty(self) -> None:
"""Send an empty frame to prevent the connection from reaching an idle timeout."""
if self._network_trace:
_LOGGER.debug("-> EmptyFrame()", extra=self._network_trace_params)

if self._error:
raise self._error

try:
if self._can_write():
if self._network_trace:
_LOGGER.debug("-> EmptyFrame()", extra=self._network_trace_params)
await self._transport.write(EMPTY_FRAME)
self._last_frame_sent_time = time.time()
except (OSError, IOError, SSLError, socket.error) as exc:
Expand Down Expand Up @@ -533,7 +532,7 @@ async def _incoming_close(self, channel: int, frame: Tuple[Any, ...]) -> None:
self._error = AMQPConnectionError(
condition=frame[0][0], description=frame[0][1], info=frame[0][2]
)
_LOGGER.error(
_LOGGER.warning(
"Connection closed with error: %r", frame[0],
extra=self._network_trace_params
)
Expand Down Expand Up @@ -682,7 +681,10 @@ async def _process_outgoing_frame(self, channel: int, frame) -> None:
ConnectionState.OPEN_SENT,
ConnectionState.OPENED,
]:
raise ValueError("Connection not open.")
raise AMQPConnectionError(
ErrorCondition.SocketError,
description="Connection not open."
)
now = time.time()
if get_local_timeout(
now,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def update_token(self) -> None:
utc_from_timestamp(self._expires_on),
)

def handle_token(self) -> bool: # pylint: disable=inconsistent-return-statements
def handle_token(self) -> bool: # pylint: disable=inconsistent-return-statements
if not self._cbs_link_ready():
return False
self._update_status()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ def _keep_alive(self):
current_time = time.time()
elapsed_time = current_time - start_time
if elapsed_time >= self._keep_alive_interval:
_logger.debug("Keeping %r connection alive.", self.__class__.__name__)
self._connection.listen(wait=self._socket_timeout, batch=self._link.current_link_credit)
start_time = current_time
time.sleep(1)
Expand Down Expand Up @@ -827,7 +826,7 @@ def _client_ready(self):
if not self._link:
self._link = self._session.create_receiver_link(
source_address=self.source,
link_credit=self._link_credit, # link_credit=0 on flow frame sent before client is ready
link_credit=0, # link_credit=0 on flow frame sent before client is ready
send_settle_mode=self._send_settle_mode,
rcv_settle_mode=self._receive_settle_mode,
max_message_size=self._max_message_size,
Expand All @@ -852,7 +851,7 @@ def _client_run(self, **kwargs):
"""
try:
if self._link.current_link_credit <= 0:
self._link.flow()
self._link.flow(link_credit=self._link_credit)
self._connection.listen(wait=self._socket_timeout, **kwargs)
except ValueError:
_logger.info("Timeout reached, closing receiver.", extra=self._network_trace_params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Header(NamedTuple):
This field contains the relative Message priority. Higher numbers indicate higher priority Messages.
Messages with higher priorities MAY be delivered before those with lower priorities. An AMQP intermediary
implementing distinct priority levels MUST do so in the following manner:

- If n distince priorities are implemented and n is less than 10 - priorities 0 to (5 - ceiling(n/2))
MUST be treated equivalently and MUST be the lowest effective priority. The priorities (4 + fioor(n/2))
and above MUST be treated equivalently and MUST be the highest effective priority. The priorities
Expand Down Expand Up @@ -184,7 +184,7 @@ class Message(NamedTuple):
delivery_annotations: Optional[Dict[Union[str, bytes], Any]] = None
message_annotations: Optional[Dict[Union[str, bytes], Any]] = None
properties: Optional[Properties] = None
application_properties: Optional[Dict[Union[str, bytes], Any]] = None # TODO: make not read-only
application_properties: Optional[Dict[Union[str, bytes], Any]] = None
data: Optional[bytes] = None
sequence: Optional[List[Any]] = None
value: Optional[Any] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ async def on_event_received(event):
assert consumer._handler._connection._state == uamqp.c_uamqp.ConnectionState.DISCARDING
await consumer.receive(batch=False, max_batch_size=1, max_wait_time=10)
else:
await consumer._handler.do_work_async()
with pytest.raises(error.AMQPConnectionError):
await consumer._handler.do_work_async()
assert consumer._handler._connection.state == constants.ConnectionState.END
try:
await asyncio.wait_for(consumer.receive(), timeout=10)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def on_event_received(event):
consumer._handler.do_work()
assert consumer._handler._connection._state == uamqp.c_uamqp.ConnectionState.DISCARDING
else:
consumer._handler.do_work()
with pytest.raises(error.AMQPConnectionError):
swathipil marked this conversation as resolved.
Show resolved Hide resolved
consumer._handler.do_work()
assert consumer._handler._connection.state == constants.ConnectionState.END

duration = 10
Expand Down
Loading