From 34ecd47163ebf25a374530ab3066036e4831c00f Mon Sep 17 00:00:00 2001 From: Carina Antunes Date: Fri, 22 Jul 2022 17:58:22 +0200 Subject: [PATCH] fix race condition between ack/nack and disconnect --- Makefile | 4 +- stomp/connect.py | 6 +- stomp/transport.py | 85 +++++++++++++++++++------- tests/setup.ini | 6 +- tests/test_disconnect_wait.py | 110 ++++++++++++++++++++++++++++++++++ 5 files changed, 185 insertions(+), 26 deletions(-) create mode 100644 tests/test_disconnect_wait.py diff --git a/Makefile b/Makefile index c8524b1c..78d077bf 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ PYTHON_VERSION_MAJOR:=$(shell $(PYTHON) -c "import sys;print(sys.version_info[0] PLATFORM := $(shell uname) VERSION :=$(shell poetry version | sed 's/stomp.py\s*//g' | sed 's/\./, /g') SHELL=/bin/bash -ARTEMIS_VERSION=2.22.0 +ARTEMIS_VERSION=2.23.1 TEST_CMD := $(shell podman network exists stomptest &> /dev/null && echo "podman unshare --rootless-netns poetry" || echo "poetry") all: test install @@ -77,7 +77,7 @@ docker-image: docker/tmp/activemq-artemis-bin.tar.gz ssl-setup run-docker: - docker run --add-host="my.example.com:127.0.0.1" --add-host="my.example.org:127.0.0.1" --add-host="my.example.net:127.0.0.1" -d -p 61613:61613 -p 62613:62613 -p 62614:62614 -p 63613:63613 -p 64613:64613 --name stomppy -it stomppy + docker run --add-host="my.example.com:127.0.0.1" --add-host="my.example.org:127.0.0.1" --add-host="my.example.net:127.0.0.1" -d -p 61613:61613 -p 62613:62613 -p 62614:62614 -p 63613:63613 -p 64613:64613 -p 8161:8161 --name stomppy -it stomppy docker ps docker exec -it stomppy /bin/sh -c "/etc/init.d/activemq start" docker exec -it stomppy /bin/sh -c "/etc/init.d/stompserver start" diff --git a/stomp/connect.py b/stomp/connect.py index c1858f1d..178fa760 100644 --- a/stomp/connect.py +++ b/stomp/connect.py @@ -150,14 +150,18 @@ def connect(self, *args, **kwargs): self.transport.start() Protocol11.connect(self, *args, **kwargs) - def disconnect(self, receipt=None, headers=None, **keyword_headers): + def disconnect(self, receipt=None, headers=None, wait=False, **keyword_headers): """ Call the protocol disconnection, and then stop the transport itself. :param str receipt: the receipt to use with the disconnect :param dict headers: a map of any additional headers to send with the disconnection + :param bool wait: wait for the started messages to finish ack/nack before disconnection :param keyword_headers: any additional headers to send with the disconnection """ + if wait: + self.transport.begin_stop() + Protocol11.disconnect(self, receipt, headers, **keyword_headers) if receipt is not None: self.transport.stop() diff --git a/stomp/transport.py b/stomp/transport.py index 165aded6..1162d8f9 100644 --- a/stomp/transport.py +++ b/stomp/transport.py @@ -59,6 +59,8 @@ class BaseTransport(stomp.listener.Publisher): __content_length_re = re.compile(b"^content-length[:]\\s*(?P[0-9]+)", re.MULTILINE) def __init__(self, auto_decode=True, encoding="utf-8", is_eol_fc=is_eol_default): + self.__receiver_thread_sending_condition = threading.Condition() + self.__receiver_thread_sent = True self.__recvbuf = b"" self.listeners = {} self.running = False @@ -79,7 +81,7 @@ def __init__(self, auto_decode=True, encoding="utf-8", is_eol_fc=is_eol_default) self.__listeners_change_condition = threading.Condition() self.__receiver_thread_exit_condition = threading.Condition() self.__receiver_thread_exited = False - self.__send_wait_condition = threading.Condition() + self.__receipt_wait_condition = threading.Condition() self.__connect_wait_condition = threading.Condition() self.__auto_decode = auto_decode self.__encoding = encoding @@ -112,6 +114,13 @@ def start(self): logging.info("Created thread %s using func %s", receiver_thread, self.create_thread_fc) self.notify("connecting") + def begin_stop(self): + """ + Begin stop of the connection. Stops reading new messages but keep thread to finish ack/nack of messages. + """ + # emit stop reading new messages + self.running = False + def stop(self): """ Stop the connection. Performs a clean shutdown by waiting for the @@ -206,9 +215,10 @@ def notify(self, frame_type, frame=None): # logic for wait-on-receipt notification receipt = frame.headers["receipt-id"] receipt_value = self.__receipts.get(receipt) - with self.__send_wait_condition: + + with self.__receipt_wait_condition: self.set_receipt(receipt, None) - self.__send_wait_condition.notify() + self.__receipt_wait_condition.notifyAll() if receipt_value == CMD_DISCONNECT: self.set_connected(False) @@ -232,7 +242,7 @@ def notify(self, frame_type, frame=None): if not notify_func: logging.debug("listener %s has no method on_%s", listener, frame_type) continue - if frame_type in ("heartbeat", "disconnected"): + if frame_type in ("disconnecting", "heartbeat", "disconnected"): notify_func() continue if frame_type == "connecting": @@ -252,26 +262,36 @@ def transmit(self, frame): :param Frame frame: the Frame object to transmit """ - with self.__listeners_change_condition: - listeners = sorted(self.listeners.items()) + with self.__receiver_thread_sending_condition: + self.__receiver_thread_sent = False + self.__receiver_thread_sending_condition.notify_all() - for (_, listener) in listeners: - try: - listener.on_send(frame) - except AttributeError: - continue + try: + with self.__listeners_change_condition: + listeners = sorted(self.listeners.items()) - if frame.cmd == CMD_DISCONNECT and HDR_RECEIPT in frame.headers: - self.__disconnect_receipt = frame.headers[HDR_RECEIPT] + for (_, listener) in listeners: + try: + listener.on_send(frame) + except AttributeError: + continue - lines = convert_frame(frame) - packed_frame = pack(lines) + if frame.cmd == CMD_DISCONNECT and HDR_RECEIPT in frame.headers: + self.__disconnect_receipt = frame.headers[HDR_RECEIPT] - if logging.isEnabledFor(logging.DEBUG): - logging.debug("Sending frame: %s", clean_lines(lines)) - else: - logging.info("Sending frame: %r", frame.cmd or "heartbeat") - self.send(packed_frame) + lines = convert_frame(frame) + packed_frame = pack(lines) + + if logging.isEnabledFor(logging.DEBUG): + logging.debug("Sending frame: %s", clean_lines(lines)) + else: + logging.info("Sending frame: %r", frame.cmd or "heartbeat") + self.send(packed_frame) + + finally: + with self.__receiver_thread_sending_condition: + self.__receiver_thread_sent = True + self.__receiver_thread_sending_condition.notify_all() def send(self, encoded_frame): """ @@ -323,6 +343,21 @@ def wait_for_connection(self, timeout=None): if not self.running or not self.is_connected(): raise exception.ConnectFailedException() + def wait_for_receipt(self, receipt_id, timeout=None): + """ + Wait until we've received a receipt from the server. + + :param str receipt_id: the receipt_id + :param float timeout: how long to wait, in seconds + """ + if timeout is not None: + wait_time = timeout / 10.0 + else: + wait_time = None + with self.__receipt_wait_condition: + while self.__receipts.get(receipt_id): + self.__receipt_wait_condition.wait(wait_time) + def __receiver_loop(self): """ Main loop listening for incoming data. @@ -345,6 +380,16 @@ def __receiver_loop(self): if self.__auto_decode: f.body = decode(f.body) self.process_frame(f, frame) + + # wait to finish process messages in progress + with self.__receiver_thread_sending_condition: + while not self.__receiver_thread_sent: + self.__receiver_thread_sending_condition.wait() + + with self.__receiver_thread_exit_condition: + while not self.__receiver_thread_exited and self.is_connected(): + self.__receiver_thread_exit_condition.wait() + except exception.ConnectionClosedException: if self.running: # diff --git a/tests/setup.ini b/tests/setup.ini index a066f81e..b41c9ef6 100644 --- a/tests/setup.ini +++ b/tests/setup.ini @@ -1,5 +1,5 @@ [default] -host = 172.17.0.2 +host = localhost port = 62613 ssl_port = 62614 ssl_expired_port = 62619 @@ -10,7 +10,7 @@ password = password port = 62613 [rabbitmq] -host = 172.17.0.2 +host = localhost port = 61613 user = guest password = guest @@ -24,7 +24,7 @@ host = my.example.com ssl_port = 65001 [artemis] -host = 172.17.0.2 +host = localhost port = 61615 user = testuser password = password \ No newline at end of file diff --git a/tests/test_disconnect_wait.py b/tests/test_disconnect_wait.py new file mode 100644 index 00000000..61e62c1f --- /dev/null +++ b/tests/test_disconnect_wait.py @@ -0,0 +1,110 @@ +import logging + +import stomp +from stomp.listener import TestListener +from .testutils import * + + +class BrokenConnectionListener(TestListener): + def __init__(self, connection=None): + TestListener.__init__(self) + self.connection = connection + self.messages_started = 0 + self.messages_completed = 0 + + def on_error(self, frame): + TestListener.on_error(self, frame) + assert frame.body.startswith("org.apache.activemq.transport.stomp.ProtocolException: Not connected") + + def on_message(self, frame): + TestListener.on_message(self, frame) + self.messages_started += 1 + + if self.connection.is_connected(): + try: + self.connection.ack(frame.headers["message-id"], frame.headers["subscription"]) + self.messages_completed += 1 + except BrokenPipeError: + logging.error("Expected BrokenPipeError") + self.errors += 1 + + +def conn(): + c = stomp.Connection11(get_default_host(), try_loopback_connect=False) + c.set_listener("testlistener", BrokenConnectionListener(c)) + c.connect(get_default_user(), get_default_password(), wait=True) + return c + + +def run_race_condition_situation(conn, wait): + # happens when using ack mode "client-individual" + # some load, eg > 50 messages received at same time (simulated with transaction) + listener = conn.get_listener("testlistener") # type: BrokenConnectionListener + + queuename = "/queue/disconnectmidack-%s" % listener.timestamp + conn.subscribe(destination=queuename, id=1, ack="client-individual") + + trans_id = conn.begin() + for i in range(50): + conn.send(body="test message", destination=queuename, transaction=trans_id) + conn.commit(transaction=trans_id) + + listener.wait_for_message() + conn.disconnect(wait=wait) + + # wait for some messages to start between the time of disconnect start and finish (when the race condition happens) + # needed to check result of listener.errors + time.sleep(0.5) + + # return listener for asserts + return listener + + +def assert_race_condition_disconnect_mid_ack(conn, wait=False): + listener = run_race_condition_situation(conn, wait) + + started = listener.messages_started + logging.debug("messages started %d", started) + + assert listener.connections == 1, "should have received 1 connection acknowledgement" + assert listener.messages == started, f"should have received {started} message" + + # Causes either BrokenPipeError or ProtocolException: Not connected + assert listener.errors >= 1, "should have at least one error" + assert listener.messages_started > listener.messages_completed, f"should have not completed all started" + + +def assert_no_race_condition_disconnect_mid_ack(conn, wait=False): + listener = run_race_condition_situation(conn, wait) + + started = listener.messages_started + logging.debug("T%s : messages started %d", started, threading.get_native_id()) + + assert listener.connections == 1, "should have received 1 connection acknowledgement" + assert listener.messages == started, f"should have received {started} message" + + assert listener.errors == 0, "should not have errors" + assert listener.messages_started == listener.messages_completed, f"should have completed all started" + + +def test_assert_race_condition_in_disconnect_mid_ack(): + found_race_condition = False + retries_until_race_condition = 0 + while not found_race_condition: + try: + assert_race_condition_disconnect_mid_ack(conn()) + found_race_condition = True + except AssertionError as e: + retries_until_race_condition += 1 + continue + + assert found_race_condition is True + # might occur at first try, might take 50 retries + logging.warning("Tries until race condition: %d", retries_until_race_condition) + + +def test_assert_fixed_race_condition_in_disconnect_mid_ack(): + # same test case but asserts no error + # you can increase forever, it always passes + for n in range(100): + assert_no_race_condition_disconnect_mid_ack(conn(), wait=True)