Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#204 P2P replica mode #205

Merged
merged 2 commits into from
Apr 7, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions hivemind/p2p/p2p_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ async def create(cls, *args, quic=1, tls=1, conn_manager=1, dht_client=1,
break
return self

@classmethod
async def replica(cls, daemon_listen_port: int, host_port: int):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: rename this method into a verb (replicate) or context (from_running_daemon)?

self = cls()
# There is no child under control
# Use external already running p2pd
self._child = None
self._assign_daemon_ports(host_port, daemon_listen_port)
self._client_listen_port = find_open_port()
self._client = p2pclient.Client(
Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'),
Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}'))
await self._identify_client(0)
return self

def _initialize(self, proc_args: tp.List[str]) -> None:
proc_args = copy.deepcopy(proc_args)
proc_args.extend(self._make_process_args(
Expand Down
115 changes: 97 additions & 18 deletions tests/test_p2p_daemon.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import multiprocessing as mp
import subprocess
from functools import partial

from hivemind.p2p.p2p_daemon_bindings.datastructures import ID

Expand All @@ -27,6 +28,10 @@ def is_process_running(pid: int) -> bool:
return subprocess.check_output(cmd, shell=True).decode('utf-8').strip() == RUNNING


async def replicate_if_needed(p2p: P2P, replicate: bool):
return await P2P.replica(p2p._daemon_listen_port, p2p._host_port) if replicate else p2p


@pytest.mark.asyncio
async def test_daemon_killed_on_del():
p2p_daemon = await P2P.create()
Expand All @@ -38,6 +43,21 @@ async def test_daemon_killed_on_del():
assert not is_process_running(child_pid)


@pytest.mark.asyncio
async def test_daemon_replica_does_not_affect_primary():
p2p_daemon = await P2P.create()
p2p_replica = await P2P.replica(p2p_daemon._daemon_listen_port, p2p_daemon._host_port)

child_pid = p2p_daemon._child.pid
assert is_process_running(child_pid)

p2p_replica.__del__()
assert is_process_running(child_pid)

p2p_daemon.__del__()
assert not is_process_running(child_pid)


def handle_square(x):
return x ** 2

Expand All @@ -50,10 +70,15 @@ def handle_add(args):


@pytest.mark.parametrize(
'should_cancel', [True, False]
'should_cancel,replicate', [
(True, False),
(True, True),
(False, False),
(False, True),
]
)
@pytest.mark.asyncio
async def test_call_unary_handler(should_cancel, handle_name="handle"):
async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
handler_cancelled = False

async def ping_handler(request, context):
Expand All @@ -67,14 +92,16 @@ async def ping_handler(request, context):
node_id=context.ours_id.encode(), rpc_port=context.ours_port),
sender_endpoint=context.handle_name, available=True)

server = await P2P.create()
server_pid = server._child.pid
server_primary = await P2P.create()
server = await replicate_if_needed(server_primary, replicate)
server_pid = server_primary._child.pid
await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
dht_pb2.PingResponse)
assert is_process_running(server_pid)

client = await P2P.create()
client_pid = client._child.pid
client_primary = await P2P.create()
client = await replicate_if_needed(client_primary, replicate)
client_pid = client_primary._child.pid
assert is_process_running(client_pid)

ping_request = dht_pb2.PingRequest(
Expand All @@ -100,10 +127,10 @@ async def ping_handler(request, context):
assert not handler_cancelled

await server.stop_listening()
server.__del__()
server_primary.__del__()
assert not is_process_running(server_pid)

client.__del__()
client_primary.__del__()
assert not is_process_running(client_pid)


Expand Down Expand Up @@ -131,7 +158,6 @@ async def test_call_peer_single_process(test_input, handle, handler_name="handle
result = await client.call_peer_handler(server.id, handler_name, test_input)
assert result == handle(test_input)

await server.stop_listening()
server.__del__()
assert not is_process_running(server_pid)

Expand Down Expand Up @@ -188,30 +214,83 @@ async def test_call_peer_different_processes():


@pytest.mark.parametrize(
"test_input,handle",
"test_input,handle,replicate",
[
pytest.param(np.random.randn(2, 3), handle_square, id="square"),
pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, id="add"),
pytest.param(np.random.randn(2, 3), handle_square, False, id="square_primary"),
pytest.param(np.random.randn(2, 3), handle_square, True, id="square_replica"),
pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, False, id="add_primary"),
pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, True, id="add_replica"),
]
)
@pytest.mark.asyncio
async def test_call_peer_numpy(test_input, handle, handler_name="handle"):
server = await P2P.create()
async def test_call_peer_numpy(test_input, handle, replicate, handler_name="handle"):
server_primary = await P2P.create()
server = await replicate_if_needed(server_primary, replicate)
await server.add_stream_handler(handler_name, handle)
client = await P2P.create()
client_primary = await P2P.create()
client = await replicate_if_needed(client_primary, replicate)

await asyncio.sleep(1)
result = await client.call_peer_handler(server.id, handler_name, test_input)
assert np.allclose(result, handle(test_input))


@pytest.mark.parametrize(
"replicate",
[
pytest.param(False, id="primary"),
pytest.param(True, id="replica"),
]
)
@pytest.mark.asyncio
async def test_call_peer_error(handler_name="handle"):
server = await P2P.create()
async def test_call_peer_error(replicate, handler_name="handle"):
server_primary = await P2P.create()
server = await replicate_if_needed(server_primary, replicate)
await server.add_stream_handler(handler_name, handle_add)
client = await P2P.create()
client_primary = await P2P.create()
client = await replicate_if_needed(client_primary, replicate)

await asyncio.sleep(1)
result = await client.call_peer_handler(server.id, handler_name,
[np.zeros((2, 3)), np.zeros((3, 2))])
assert type(result) == ValueError


@pytest.mark.asyncio
async def test_handlers_on_different_replicas(handler_name="handle"):
def handler(arg, key):
return key

server_primary = await P2P.create()
server_id = server_primary.id
await server_primary.add_stream_handler(handler_name, partial(handler, key="primary"))

server_replica1 = await replicate_if_needed(server_primary, True)
await server_replica1.add_stream_handler(handler_name + "1", partial(handler, key="replica1"))

server_replica2 = await replicate_if_needed(server_primary, True)
await server_replica2.add_stream_handler(handler_name + "2", partial(handler, key="replica2"))

client = await P2P.create()
await asyncio.sleep(1)
result = await client.call_peer_handler(server_id, handler_name, "")
assert result == "primary"

result = await client.call_peer_handler(server_id, handler_name + "1", "")
assert result == "replica1"

result = await client.call_peer_handler(server_id, handler_name + "2", "")
assert result == "replica2"

await server_replica1.stop_listening()
await server_replica2.stop_listening()

# Primary does not handle replicas protocols
with pytest.raises(P2P.IncompleteRead):
await client.call_peer_handler(server_id, handler_name + "1", "")
with pytest.raises(P2P.IncompleteRead):
await client.call_peer_handler(server_id, handler_name + "2", "")

await server_primary.stop_listening()
server_primary.__del__()
client.__del__()