Skip to content

Commit

Permalink
fix race condition between ack/nack and disconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
carantunes committed Jul 22, 2022
1 parent 3558745 commit 34ecd47
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 26 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ PYTHON_VERSION_MAJOR:=$(shell $(PYTHON) -c "import sys;print(sys.version_info[0]
PLATFORM := $(shell uname)
VERSION :=$(shell poetry version | sed 's/stomp.py\s*//g' | sed 's/\./, /g')
SHELL=/bin/bash
ARTEMIS_VERSION=2.22.0
ARTEMIS_VERSION=2.23.1
TEST_CMD := $(shell podman network exists stomptest &> /dev/null && echo "podman unshare --rootless-netns poetry" || echo "poetry")

all: test install
Expand Down Expand Up @@ -77,7 +77,7 @@ docker-image: docker/tmp/activemq-artemis-bin.tar.gz ssl-setup


run-docker:
docker run --add-host="my.example.com:127.0.0.1" --add-host="my.example.org:127.0.0.1" --add-host="my.example.net:127.0.0.1" -d -p 61613:61613 -p 62613:62613 -p 62614:62614 -p 63613:63613 -p 64613:64613 --name stomppy -it stomppy
docker run --add-host="my.example.com:127.0.0.1" --add-host="my.example.org:127.0.0.1" --add-host="my.example.net:127.0.0.1" -d -p 61613:61613 -p 62613:62613 -p 62614:62614 -p 63613:63613 -p 64613:64613 -p 8161:8161 --name stomppy -it stomppy
docker ps
docker exec -it stomppy /bin/sh -c "/etc/init.d/activemq start"
docker exec -it stomppy /bin/sh -c "/etc/init.d/stompserver start"
Expand Down
6 changes: 5 additions & 1 deletion stomp/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,18 @@ def connect(self, *args, **kwargs):
self.transport.start()
Protocol11.connect(self, *args, **kwargs)

def disconnect(self, receipt=None, headers=None, **keyword_headers):
def disconnect(self, receipt=None, headers=None, wait=False, **keyword_headers):
"""
Call the protocol disconnection, and then stop the transport itself.
:param str receipt: the receipt to use with the disconnect
:param dict headers: a map of any additional headers to send with the disconnection
:param bool wait: wait for the started messages to finish ack/nack before disconnection
:param keyword_headers: any additional headers to send with the disconnection
"""
if wait:
self.transport.begin_stop()

Protocol11.disconnect(self, receipt, headers, **keyword_headers)
if receipt is not None:
self.transport.stop()
Expand Down
85 changes: 65 additions & 20 deletions stomp/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class BaseTransport(stomp.listener.Publisher):
__content_length_re = re.compile(b"^content-length[:]\\s*(?P<value>[0-9]+)", re.MULTILINE)

def __init__(self, auto_decode=True, encoding="utf-8", is_eol_fc=is_eol_default):
self.__receiver_thread_sending_condition = threading.Condition()
self.__receiver_thread_sent = True
self.__recvbuf = b""
self.listeners = {}
self.running = False
Expand All @@ -79,7 +81,7 @@ def __init__(self, auto_decode=True, encoding="utf-8", is_eol_fc=is_eol_default)
self.__listeners_change_condition = threading.Condition()
self.__receiver_thread_exit_condition = threading.Condition()
self.__receiver_thread_exited = False
self.__send_wait_condition = threading.Condition()
self.__receipt_wait_condition = threading.Condition()
self.__connect_wait_condition = threading.Condition()
self.__auto_decode = auto_decode
self.__encoding = encoding
Expand Down Expand Up @@ -112,6 +114,13 @@ def start(self):
logging.info("Created thread %s using func %s", receiver_thread, self.create_thread_fc)
self.notify("connecting")

def begin_stop(self):
    """
    Begin stopping the connection.

    Only clears the ``running`` flag: new messages stop being read, while the
    thread is kept so that ack/nack of messages already started can finish.
    """
    # clearing the flag is the signal to stop reading new messages
    self.running = False

def stop(self):
"""
Stop the connection. Performs a clean shutdown by waiting for the
Expand Down Expand Up @@ -206,9 +215,10 @@ def notify(self, frame_type, frame=None):
# logic for wait-on-receipt notification
receipt = frame.headers["receipt-id"]
receipt_value = self.__receipts.get(receipt)
with self.__send_wait_condition:

with self.__receipt_wait_condition:
self.set_receipt(receipt, None)
self.__send_wait_condition.notify()
self.__receipt_wait_condition.notifyAll()

if receipt_value == CMD_DISCONNECT:
self.set_connected(False)
Expand All @@ -232,7 +242,7 @@ def notify(self, frame_type, frame=None):
if not notify_func:
logging.debug("listener %s has no method on_%s", listener, frame_type)
continue
if frame_type in ("heartbeat", "disconnected"):
if frame_type in ("disconnecting", "heartbeat", "disconnected"):
notify_func()
continue
if frame_type == "connecting":
Expand All @@ -252,26 +262,36 @@ def transmit(self, frame):
:param Frame frame: the Frame object to transmit
"""
with self.__listeners_change_condition:
listeners = sorted(self.listeners.items())
with self.__receiver_thread_sending_condition:
self.__receiver_thread_sent = False
self.__receiver_thread_sending_condition.notify_all()

for (_, listener) in listeners:
try:
listener.on_send(frame)
except AttributeError:
continue
try:
with self.__listeners_change_condition:
listeners = sorted(self.listeners.items())

if frame.cmd == CMD_DISCONNECT and HDR_RECEIPT in frame.headers:
self.__disconnect_receipt = frame.headers[HDR_RECEIPT]
for (_, listener) in listeners:
try:
listener.on_send(frame)
except AttributeError:
continue

lines = convert_frame(frame)
packed_frame = pack(lines)
if frame.cmd == CMD_DISCONNECT and HDR_RECEIPT in frame.headers:
self.__disconnect_receipt = frame.headers[HDR_RECEIPT]

if logging.isEnabledFor(logging.DEBUG):
logging.debug("Sending frame: %s", clean_lines(lines))
else:
logging.info("Sending frame: %r", frame.cmd or "heartbeat")
self.send(packed_frame)
lines = convert_frame(frame)
packed_frame = pack(lines)

if logging.isEnabledFor(logging.DEBUG):
logging.debug("Sending frame: %s", clean_lines(lines))
else:
logging.info("Sending frame: %r", frame.cmd or "heartbeat")
self.send(packed_frame)

finally:
with self.__receiver_thread_sending_condition:
self.__receiver_thread_sent = True
self.__receiver_thread_sending_condition.notify_all()

def send(self, encoded_frame):
"""
Expand Down Expand Up @@ -323,6 +343,21 @@ def wait_for_connection(self, timeout=None):
if not self.running or not self.is_connected():
raise exception.ConnectFailedException()

def wait_for_receipt(self, receipt_id, timeout=None):
    """
    Wait until we've received a receipt from the server.

    Blocks for as long as ``receipt_id`` still has a truthy entry in the
    pending receipts map. When ``timeout`` is given the wait is abandoned once
    that many seconds have elapsed, even if the receipt is still pending;
    callers that need to know the outcome must check the receipt themselves.

    :param str receipt_id: the receipt_id
    :param float timeout: maximum time to wait, in seconds (None waits indefinitely)
    """
    import time

    if timeout is None:
        deadline = None
        poll_interval = None
    else:
        deadline = time.monotonic() + timeout
        # keep re-checking the receipt at a tenth of the timeout, as before,
        # so a cleared receipt is noticed even without a notification
        poll_interval = timeout / 10.0

    with self.__receipt_wait_condition:
        while self.__receipts.get(receipt_id):
            if deadline is None:
                self.__receipt_wait_condition.wait()
                continue
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                # timeout elapsed: stop waiting instead of looping forever
                break
            self.__receipt_wait_condition.wait(min(poll_interval, remaining))

def __receiver_loop(self):
"""
Main loop listening for incoming data.
Expand All @@ -345,6 +380,16 @@ def __receiver_loop(self):
if self.__auto_decode:
f.body = decode(f.body)
self.process_frame(f, frame)

# wait to finish process messages in progress
with self.__receiver_thread_sending_condition:
while not self.__receiver_thread_sent:
self.__receiver_thread_sending_condition.wait()

with self.__receiver_thread_exit_condition:
while not self.__receiver_thread_exited and self.is_connected():
self.__receiver_thread_exit_condition.wait()

except exception.ConnectionClosedException:
if self.running:
#
Expand Down
6 changes: 3 additions & 3 deletions tests/setup.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[default]
host = 172.17.0.2
host = localhost
port = 62613
ssl_port = 62614
ssl_expired_port = 62619
Expand All @@ -10,7 +10,7 @@ password = password
port = 62613

[rabbitmq]
host = 172.17.0.2
host = localhost
port = 61613
user = guest
password = guest
Expand All @@ -24,7 +24,7 @@ host = my.example.com
ssl_port = 65001

[artemis]
host = 172.17.0.2
host = localhost
port = 61615
user = testuser
password = password
110 changes: 110 additions & 0 deletions tests/test_disconnect_wait.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import logging

import stomp
from stomp.listener import TestListener
from .testutils import *


class BrokenConnectionListener(TestListener):
    """
    Test listener that acks every received message on its connection and
    counts how many acks were started and how many completed.
    """

    def __init__(self, connection=None):
        """
        :param connection: the connection used to ack received messages
        """
        super().__init__()
        self.connection = connection
        # messages seen by on_message vs. messages whose ack call returned
        self.messages_started = 0
        self.messages_completed = 0

    def on_error(self, frame):
        """
        Record the error frame and check it is the "Not connected" protocol error.
        """
        super().on_error(frame)
        expected_prefix = "org.apache.activemq.transport.stomp.ProtocolException: Not connected"
        assert frame.body.startswith(expected_prefix)

    def on_message(self, frame):
        """
        Record the message, then ack it if the connection is still connected.
        A BrokenPipeError raised by the ack is logged and counted as an error.
        """
        super().on_message(frame)
        self.messages_started += 1

        if not self.connection.is_connected():
            return

        headers = frame.headers
        try:
            self.connection.ack(headers["message-id"], headers["subscription"])
        except BrokenPipeError:
            logging.error("Expected BrokenPipeError")
            self.errors += 1
        else:
            self.messages_completed += 1


def conn():
    """
    Create a STOMP 1.1 connection to the default host, register a
    BrokenConnectionListener under the name "testlistener" and connect,
    waiting for the connection to be established.

    :return: the connected connection
    """
    connection = stomp.Connection11(get_default_host(), try_loopback_connect=False)
    listener = BrokenConnectionListener(connection)
    connection.set_listener("testlistener", listener)
    connection.connect(get_default_user(), get_default_password(), wait=True)
    return connection


def run_race_condition_situation(conn, wait):
    """
    Drive the scenario in which a disconnect overlaps with message acks.

    Subscribes with ack mode "client-individual", sends 50 messages inside one
    transaction so they are delivered together, waits for a message to arrive
    and then disconnects.

    :param conn: connected connection carrying the "testlistener" listener
    :param bool wait: forwarded to ``conn.disconnect(wait=...)``
    :return: the "testlistener" listener, for the caller's asserts
    """
    # happens when using ack mode "client-individual"
    # some load, eg > 50 messages received at same time (simulated with transaction)
    test_listener = conn.get_listener("testlistener")  # type: BrokenConnectionListener

    destination = "/queue/disconnectmidack-%s" % test_listener.timestamp
    conn.subscribe(destination=destination, id=1, ack="client-individual")

    transaction_id = conn.begin()
    for _ in range(50):
        conn.send(body="test message", destination=destination, transaction=transaction_id)
    conn.commit(transaction=transaction_id)

    test_listener.wait_for_message()
    conn.disconnect(wait=wait)

    # give messages time to start between the beginning and the end of the
    # disconnect (when the race condition happens); needed before the caller
    # checks listener.errors
    time.sleep(0.5)

    return test_listener


def assert_race_condition_disconnect_mid_ack(conn, wait=False):
    """
    Run the disconnect-mid-ack scenario and assert that the race DID happen:
    at least one error was recorded and not every started message completed.

    :param conn: connected connection carrying the "testlistener" listener
    :param bool wait: forwarded to the disconnect call
    """
    listener = run_race_condition_situation(conn, wait)

    started_count = listener.messages_started
    logging.debug("messages started %d", started_count)

    assert listener.connections == 1, "should have received 1 connection acknowledgement"
    assert listener.messages == started_count, f"should have received {started_count} message"

    # the race shows up as either BrokenPipeError or ProtocolException: Not connected
    assert listener.errors >= 1, "should have at least one error"
    assert listener.messages_completed < listener.messages_started, "should have not completed all started"


def assert_no_race_condition_disconnect_mid_ack(conn, wait=False):
    """
    Run the disconnect-mid-ack scenario and assert that the race did NOT
    happen: no errors were recorded and every started message completed.

    :param conn: connected connection carrying the "testlistener" listener
    :param bool wait: forwarded to the disconnect call
    """
    listener = run_race_condition_situation(conn, wait)

    started = listener.messages_started
    # argument order must follow the format string: thread id first, then the count
    logging.debug("T%s : messages started %d", threading.get_native_id(), started)

    assert listener.connections == 1, "should have received 1 connection acknowledgement"
    assert listener.messages == started, f"should have received {started} message"

    assert listener.errors == 0, "should not have errors"
    assert listener.messages_started == listener.messages_completed, "should have completed all started"


def test_assert_race_condition_in_disconnect_mid_ack():
    """
    Reproduce the race between ack and disconnect when disconnecting without
    waiting. The scenario is retried on a fresh connection until the race is
    observed; the test fails if it is never observed within the attempt limit.
    """
    # bounded so that a setup on which the race never reproduces fails the
    # test instead of hanging the whole run
    max_attempts = 200

    for failed_attempts in range(max_attempts):
        try:
            assert_race_condition_disconnect_mid_ack(conn())
        except AssertionError:
            # race not hit this time; retry with a new connection
            continue

        # might occur at first try, might take 50 retries
        logging.warning("Tries until race condition: %d", failed_attempts)
        return

    raise AssertionError("race condition not reproduced after %d attempts" % max_attempts)


def test_assert_fixed_race_condition_in_disconnect_mid_ack():
    """
    Same scenario as the race-condition test, but disconnecting with
    ``wait=True`` and asserting that no error occurs on any run.
    """
    # you can increase forever, it always passes
    repetitions = 100
    for _ in range(repetitions):
        assert_no_race_condition_disconnect_mid_ack(conn(), wait=True)

0 comments on commit 34ecd47

Please sign in to comment.