diff --git a/rosbridge_library/src/rosbridge_library/capabilities/advertise_service.py b/rosbridge_library/src/rosbridge_library/capabilities/advertise_service.py index 5ad05f865..a1db57ed5 100644 --- a/rosbridge_library/src/rosbridge_library/capabilities/advertise_service.py +++ b/rosbridge_library/src/rosbridge_library/capabilities/advertise_service.py @@ -45,9 +45,10 @@ async def handle_request(self, req, res): } self.protocol.send(request_message) - res = await future - del self.request_futures[request_id] - return res + try: + return await future + finally: + del self.request_futures[request_id] def handle_response(self, request_id, res): """ @@ -75,6 +76,8 @@ def graceful_shutdown(self): f"Service {self.service_name} was unadvertised with a service call in progress, " f"aborting service calls with request IDs {incomplete_ids}", ) + for future in self.request_futures.values(): + future.set_exception(RuntimeError(f"Service {self.service_name} was unadvertised")) self.protocol.node_handle.destroy_service(self.service_handle) @@ -128,9 +131,7 @@ def advertise_service(self, message): self.protocol.log( "warn", "Duplicate service advertised. Overwriting %s." % service_name ) - self.protocol.external_service_list[service_name].service_handle.shutdown( - "Duplicate advertiser." - ) + self.protocol.external_service_list[service_name].graceful_shutdown() del self.protocol.external_service_list[service_name] # setup and store the service information diff --git a/rosbridge_server/test/websocket/advertise_service.test.py b/rosbridge_server/test/websocket/advertise_service.test.py index 5286ba755..6097dc3f3 100644 --- a/rosbridge_server/test/websocket/advertise_service.test.py +++ b/rosbridge_server/test/websocket/advertise_service.test.py @@ -19,7 +19,8 @@ class TestAdvertiseService(unittest.TestCase): @websocket_test - async def test_two_concurrent_calls(self, node: Node, ws_client): + async def test_two_concurrent_calls(self, node: Node, make_client): + ws_client = await make_client() ws_client.sendJson( { "op": "advertise_service", diff --git a/rosbridge_server/test/websocket/advertise_service_duplicate.test.py b/rosbridge_server/test/websocket/advertise_service_duplicate.test.py new file mode 100644 index 000000000..466e44350 --- /dev/null +++ b/rosbridge_server/test/websocket/advertise_service_duplicate.test.py @@ -0,0 +1,99 @@ +import os +import sys +import unittest + +from rclpy.node import Node +from std_srvs.srv import SetBool +from twisted.python import log + +sys.path.append(os.path.dirname(__file__)) # enable importing from common.py in this directory + +import common # noqa: E402 +from common import expect_messages, sleep, websocket_test # noqa: E402 + +log.startLogging(sys.stderr) + +generate_test_description = common.generate_test_description + + +class TestAdvertiseService(unittest.TestCase): + @websocket_test + async def test_double_advertise(self, node: Node, make_client): + ws_client1 = await make_client() + ws_client1.sendJson( + { + "op": "advertise_service", + "type": "std_srvs/SetBool", + "service": "/test_service", + } + ) + client = node.create_client(SetBool, "/test_service") + client.wait_for_service() + + requests1_future, ws_client1.message_handler = expect_messages( + 1, "WebSocket 1", node.get_logger() + ) + requests1_future.add_done_callback(lambda _: node.executor.wake()) + + client.call_async(SetBool.Request(data=True)) + + requests1 = await requests1_future + self.assertEqual( + requests1, + [ + { + "op": "call_service", + "service": "/test_service", + "id": "service_request:/test_service:1", + "args": {"data": True}, + } + ], + ) + + ws_client1.sendClose() + + ws_client2 = await make_client() + ws_client2.sendJson( + { + "op": "advertise_service", + "type": "std_srvs/SetBool", + "service": "/test_service", + } + ) + + # wait for the server to handle the new advertisement + await sleep(node, 1) + + requests2_future, ws_client2.message_handler = expect_messages( + 1, "WebSocket 2", node.get_logger() + ) + requests2_future.add_done_callback(lambda _: node.executor.wake()) + + response2_future = client.call_async(SetBool.Request(data=False)) + + requests2 = await requests2_future + self.assertEqual( + requests2, + [ + { + "op": "call_service", + "id": "service_request:/test_service:1", + "service": "/test_service", + "args": {"data": False}, + } + ], + ) + + ws_client2.sendJson( + { + "op": "service_response", + "service": "/test_service", + "values": {"success": True, "message": "Hello world 2"}, + "id": "service_request:/test_service:1", + "result": True, + } + ) + + self.assertEqual( + await response2_future, SetBool.Response(success=True, message="Hello world 2") + ) diff --git a/rosbridge_server/test/websocket/call_service.test.py b/rosbridge_server/test/websocket/call_service.test.py index 993063f9e..549b4cbb3 100644 --- a/rosbridge_server/test/websocket/call_service.test.py +++ b/rosbridge_server/test/websocket/call_service.test.py @@ -19,7 +19,7 @@ class TestCallService(unittest.TestCase): @websocket_test - async def test_one_call(self, node: Node, ws_client): + async def test_one_call(self, node: Node, make_client): def service_cb(req, res): self.assertTrue(req.data) res.success = True @@ -28,6 +28,7 @@ def service_cb(req, res): service = node.create_service(SetBool, "/test_service", service_cb) + ws_client = await make_client() responses_future, ws_client.message_handler = expect_messages( 1, "WebSocket", node.get_logger() ) diff --git a/rosbridge_server/test/websocket/common.py b/rosbridge_server/test/websocket/common.py index 4949ab3dc..e55a85b36 100644 --- a/rosbridge_server/test/websocket/common.py +++ b/rosbridge_server/test/websocket/common.py @@ -103,7 +103,8 @@ def connect(): def run_websocket_test( - node_name: str, test_fn: Callable[[Node, TestClientProtocol], Awaitable[None]] + node_name: str, + test_fn: Callable[[Node, Callable[[], Awaitable[TestClientProtocol]]], Awaitable[None]], ): context = rclpy.Context() rclpy.init(context=context) @@ -112,8 +113,7 @@ def run_websocket_test( executor.add_node(node) async def task(): - ws_client = await connect_to_server(node) - await test_fn(node, ws_client) + await test_fn(node, lambda: connect_to_server(node)) reactor.callFromThread(reactor.stop) future = executor.create_task(task) @@ -142,7 +142,8 @@ def callback(): def websocket_test(test_fn): """ - Decorator for tests which use a ROS node and WebSocket server and client + Decorator for tests which use a ROS node and WebSocket server and client. + Multiple tests per file are not supported because the Twisted reactor cannot be run multiple times. """ @functools.wraps(test_fn) diff --git a/rosbridge_server/test/websocket/smoke.test.py b/rosbridge_server/test/websocket/smoke.test.py index 1e169d19b..2447ed424 100644 --- a/rosbridge_server/test/websocket/smoke.test.py +++ b/rosbridge_server/test/websocket/smoke.test.py @@ -19,7 +19,8 @@ class TestWebsocketSmoke(unittest.TestCase): @websocket_test - async def test_smoke(self, node: Node, ws_client): + async def test_smoke(self, node: Node, make_client): + ws_client = await make_client() # For consistency, the number of messages must not exceed the the protocol # Subscriber queue_size. NUM_MSGS = 10