Skip to content

Commit

Permalink
Test removing handlers in ServicerBase
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Aug 16, 2022
1 parent 71b2392 commit 281a4ff
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 66 deletions.
6 changes: 2 additions & 4 deletions hivemind/averaging/averager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
from hivemind.compression import CompressionBase, CompressionInfo, NoCompression, deserialize_torch_tensor
from hivemind.dht import DHT, DHTID
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
from hivemind.p2p import P2P, P2PContext, P2PDaemonError, P2PHandlerError, PeerID, ServicerBase
from hivemind.proto import averaging_pb2
from hivemind.utils import MPFuture, TensorDescriptor, get_logger
from hivemind.utils.asyncio import (
Expand Down Expand Up @@ -469,8 +468,7 @@ async def find_peers_or_notify_cancel():
asyncio.CancelledError,
asyncio.InvalidStateError,
P2PHandlerError,
DispatchFailure,
ControlFailure,
P2PDaemonError,
) as e:
if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
if not step.cancelled():
Expand Down
5 changes: 2 additions & 3 deletions hivemind/averaging/matchmaking.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
from hivemind.dht import DHT, DHTID
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
from hivemind.p2p import P2P, P2PContext, P2PDaemonError, P2PHandlerError, PeerID, ServicerBase
from hivemind.proto import averaging_pb2
from hivemind.utils import DHTExpiration, TimedStorage, get_dht_time, get_logger, timed_storage
from hivemind.utils.asyncio import anext, cancel_and_wait
Expand Down Expand Up @@ -239,7 +238,7 @@ async def _request_join_group(self, leader: PeerID) -> Optional[GroupInfo]:
except asyncio.TimeoutError:
logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
return None
except (P2PHandlerError, ControlFailure, DispatchFailure, StopAsyncIteration) as e:
except (P2PDaemonError, P2PHandlerError, StopAsyncIteration) as e:
logger.debug(f"{self} - failed to request potential leader {leader}:", exc_info=True)
return None

Expand Down
4 changes: 2 additions & 2 deletions hivemind/p2p/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PDaemonError, P2PHandlerError
from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
from hivemind.p2p.p2p_daemon import P2P, P2PContext
from hivemind.p2p.p2p_daemon_bindings import P2PDaemonError, P2PHandlerError, PeerID, PeerInfo
from hivemind.p2p.servicer import ServicerBase, StubBase
2 changes: 2 additions & 0 deletions hivemind/p2p/p2p_daemon_bindings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
from hivemind.p2p.p2p_daemon_bindings.utils import P2PDaemonError, P2PHandlerError
14 changes: 1 addition & 13 deletions hivemind/p2p/p2p_daemon_bindings/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from multiaddr import Multiaddr, protocols

from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
from hivemind.p2p.p2p_daemon_bindings.utils import DispatchFailure, raise_if_failed, read_pbmsg_safe, write_pbmsg
from hivemind.p2p.p2p_daemon_bindings.utils import P2PDaemonError, P2PHandlerError, DispatchFailure, raise_if_failed, read_pbmsg_safe, write_pbmsg
from hivemind.proto import p2pd_pb2 as p2pd_pb
from hivemind.utils.logging import get_logger

Expand Down Expand Up @@ -416,15 +416,3 @@ async def remove_stream_handler(self, proto: str) -> None:
raise_if_failed(resp)

del self.handlers[proto]


class P2PHandlerError(Exception):
"""
Raised if remote handled a request with an exception
"""


class P2PDaemonError(Exception):
"""
Raised if daemon failed to handle request
"""
16 changes: 14 additions & 2 deletions hivemind/p2p/p2p_daemon_bindings/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,23 @@
DEFAULT_MAX_BITS: int = 64


class ControlFailure(Exception):
class P2PHandlerError(Exception):
"""
Raised if remote handled a request with an exception
"""


class P2PDaemonError(Exception):
"""
Raised if daemon failed to handle request
"""


class ControlFailure(P2PDaemonError):
pass


class DispatchFailure(Exception):
class DispatchFailure(P2PDaemonError):
pass


Expand Down
121 changes: 79 additions & 42 deletions tests/test_p2p_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from hivemind.p2p import P2P, P2PContext, ServicerBase
from hivemind.p2p import P2P, P2PContext, P2PDaemonError, ServicerBase
from hivemind.proto import test_pb2
from hivemind.utils.asyncio import anext

Expand All @@ -17,35 +17,37 @@ async def server_client():
await asyncio.gather(server.shutdown(), client.shutdown())


class UnaryUnaryServicer(ServicerBase):
async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
return test_pb2.TestResponse(number=request.number**2)


@pytest.mark.asyncio
async def test_unary_unary(server_client):
class ExampleServicer(ServicerBase):
async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
return test_pb2.TestResponse(number=request.number**2)

server, client = server_client
servicer = ExampleServicer()
servicer = UnaryUnaryServicer()
await servicer.add_p2p_handlers(server)
stub = ExampleServicer.get_stub(client, server.peer_id)
stub = UnaryUnaryServicer.get_stub(client, server.peer_id)

assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)


class StreamUnaryServicer(ServicerBase):
async def rpc_sum(
self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
) -> test_pb2.TestResponse:
result = 0
async for item in stream:
result += item.number
return test_pb2.TestResponse(number=result)


@pytest.mark.asyncio
async def test_stream_unary(server_client):
class ExampleServicer(ServicerBase):
async def rpc_sum(
self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
) -> test_pb2.TestResponse:
result = 0
async for item in stream:
result += item.number
return test_pb2.TestResponse(number=result)

server, client = server_client
servicer = ExampleServicer()
servicer = StreamUnaryServicer()
await servicer.add_p2p_handlers(server)
stub = ExampleServicer.get_stub(client, server.peer_id)
stub = StreamUnaryServicer.get_stub(client, server.peer_id)

async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
for i in range(10):
Expand All @@ -54,42 +56,40 @@ async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
assert await stub.rpc_sum(generate_requests()) == test_pb2.TestResponse(number=45)


class UnaryStreamServicer(ServicerBase):
async def rpc_count(
self, request: test_pb2.TestRequest, _context: P2PContext
) -> AsyncIterator[test_pb2.TestResponse]:
for i in range(request.number):
yield test_pb2.TestResponse(number=i)


@pytest.mark.asyncio
async def test_unary_stream(server_client):
class ExampleServicer(ServicerBase):
async def rpc_count(
self, request: test_pb2.TestRequest, _context: P2PContext
) -> AsyncIterator[test_pb2.TestResponse]:
for i in range(request.number):
yield test_pb2.TestResponse(number=i)

server, client = server_client
servicer = ExampleServicer()
servicer = UnaryStreamServicer()
await servicer.add_p2p_handlers(server)
stub = ExampleServicer.get_stub(client, server.peer_id)
stub = UnaryStreamServicer.get_stub(client, server.peer_id)

stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
i = 0
async for item in stream:
assert item == test_pb2.TestResponse(number=i)
i += 1
assert i == 10
assert [item.number async for item in stream] == list(range(10))


class StreamStreamServicer(ServicerBase):
async def rpc_powers(
self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
) -> AsyncIterator[test_pb2.TestResponse]:
async for item in stream:
yield test_pb2.TestResponse(number=item.number**2)
yield test_pb2.TestResponse(number=item.number**3)


@pytest.mark.asyncio
async def test_stream_stream(server_client):
class ExampleServicer(ServicerBase):
async def rpc_powers(
self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
) -> AsyncIterator[test_pb2.TestResponse]:
async for item in stream:
yield test_pb2.TestResponse(number=item.number**2)
yield test_pb2.TestResponse(number=item.number**3)

server, client = server_client
servicer = ExampleServicer()
servicer = StreamStreamServicer()
await servicer.add_p2p_handlers(server)
stub = ExampleServicer.get_stub(client, server.peer_id)
stub = StreamStreamServicer.get_stub(client, server.peer_id)

async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
for i in range(10):
Expand Down Expand Up @@ -153,3 +153,40 @@ async def rpc_wait(

await asyncio.sleep(0.25)
assert handler_cancelled


@pytest.mark.asyncio
async def test_removing_unary_handlers(server_client):
server1, client = server_client
server2 = await P2P.replicate(server1.daemon_listen_maddr)
servicer = UnaryUnaryServicer()
stub = UnaryUnaryServicer.get_stub(client, server1.peer_id)

for server in [server1, server2, server1]:
await servicer.add_p2p_handlers(server)
assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)

await servicer.remove_p2p_handlers(server)
with pytest.raises(P2PDaemonError):
await stub.rpc_square(test_pb2.TestRequest(number=10))

await asyncio.gather(server2.shutdown())


@pytest.mark.asyncio
async def test_removing_stream_handlers(server_client):
server1, client = server_client
server2 = await P2P.replicate(server1.daemon_listen_maddr)
servicer = UnaryStreamServicer()
stub = UnaryStreamServicer.get_stub(client, server1.peer_id)

for server in [server1, server2, server1]:
await servicer.add_p2p_handlers(server)
stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
assert [item.number async for item in stream] == list(range(10))

await servicer.remove_p2p_handlers(server)
with pytest.raises(P2PDaemonError):
await stub.rpc_count(test_pb2.TestRequest(number=10))

await asyncio.gather(server2.shutdown())

0 comments on commit 281a4ff

Please sign in to comment.