Convert DHT to libp2p backend #296

Merged: 86 commits, Jul 10, 2021
Commits
4750880  Make DHTProtocol and DHTNode work over P2P (borzunov, Jun 29, 2021)
a822899  Run daemon in quiet mode, improve interfaces and typing in hivemind.p… (borzunov, Jun 30, 2021)
ebbfe30  Use Unix sockets for daemon communication (borzunov, Jun 30, 2021)
ad17dbb  Get rid of find_open_port() in hivemind.p2p.P2P (borzunov, Jun 30, 2021)
4431d86  Rename _wait_for_client to _ping_client (borzunov, Jun 30, 2021)
bd398d1  Add test for QUIC transport that is now disabled by default (borzunov, Jun 30, 2021)
75e6ac7  Make public P2P.list_peers() method (borzunov, Jun 30, 2021)
679be89  Revert minor change to minimize diff (borzunov, Jun 30, 2021)
2ff6e2c  Test error for wrong p2pd args (borzunov, Jun 30, 2021)
1104e23  Add more logging, use .shutdown() in async code (borzunov, Jun 30, 2021)
64387cb  Fix calling .shutdown() (borzunov, Jun 30, 2021)
1041751  Fix graceful shutdowns in hivemind.p2p.P2P tests (borzunov, Jun 30, 2021)
f2e49b3  Stop ping retries if p2pd dies (borzunov, Jul 1, 2021)
ef92722  Speed up creation of DHT node swarm (borzunov, Jul 1, 2021)
3fdb80c  Cache result of .identify_maddrs() (borzunov, Jul 1, 2021)
a11b0d4  Merge branch 'master' into dht-over-p2p (borzunov, Jul 1, 2021)
6bbe58e  Use same peers for P2P and DHTNode in test_dht_node.py::test_dht_node (borzunov, Jul 1, 2021)
2cf54be  Accept multiaddrs as DHTNode initial_peers as well (borzunov, Jul 1, 2021)
759d0dc  Allow implicit P2P creation inside DHTNode (borzunov, Jul 1, 2021)
ab38ce8  Finish converting test_dht_node.py to DHT over P2P (borzunov, Jul 1, 2021)
47aadba  Extract launch_star_shaped_swarm() function (borzunov, Jul 1, 2021)
72f4258  Convert test_dht_crypto.py (borzunov, Jul 1, 2021)
6f65bde  Fix test_dht_schema.py partially (borzunov, Jul 1, 2021)
f73ca6c  Shorten code in DHTProtocol (borzunov, Jul 1, 2021)
6bacf52  Convert hivemind.dht.DHT constructor (borzunov, Jul 1, 2021)
7a3f8d7  Merge remote-tracking branch 'origin/master' into dht-over-p2p (borzunov, Jul 1, 2021)
5acd1dd  Finish converting hivemind.dht.DHT (borzunov, Jul 1, 2021)
8fc19d4  Rename identify_maddrs() -> get_visible_maddrs() (borzunov, Jul 1, 2021)
f148e44  Convert test_dht_experts.py (borzunov, Jul 1, 2021)
0a44f07  Convert benchmark_dht.py, fix bugs (borzunov, Jul 2, 2021)
c4918f9  Fix DHTNode.create docs (borzunov, Jul 2, 2021)
1da829e  Make DHTNode.need_manage_p2p private (borzunov, Jul 2, 2021)
8d72387  Implement timeouts in DHTProtocol.DHTStub (borzunov, Jul 2, 2021)
74fc538  Convert averager (borzunov, Jul 3, 2021)
7b17300  Convert hivemind.server.Server, test_{moe,training}.py (borzunov, Jul 3, 2021)
6a9e6fe  Convert test_auth.py (borzunov, Jul 3, 2021)
a42d111  Fix usages of removed dht_pb2.Ping{Request,Response} fields (borzunov, Jul 3, 2021)
382028e  Rename visible_host to announced_host in DecentralizedAverager (borzunov, Jul 3, 2021)
5e03bec  Explicitly clean up chld processes and Unix sockets after tests (borzunov, Jul 3, 2021)
5b3e015  Set default ping_n_retries = 5 everywhere (borzunov, Jul 3, 2021)
7e7ac43  Return to allowing QUIC by default (borzunov, Jul 3, 2021)
25f4c77  Make P2P listen only localhost by default (borzunov, Jul 3, 2021)
717b62f  Make P2P listen only TCP by default (borzunov, Jul 3, 2021)
07e4baf  Rename P2P params, improve docs (borzunov, Jul 3, 2021)
f9298e3  Allow passing p2p kwargs to DHT instead of an actual p2p instance (wh… (borzunov, Jul 3, 2021)
518031c  Fix readthedocs (borzunov, Jul 3, 2021)
3f35c2d  Fix bugs in test_dht_* (borzunov, Jul 3, 2021)
93ac771  Fix bugs after refactor (borzunov, Jul 3, 2021)
475d520  Refactor to use hivemind.p2p.Servicer (borzunov, Jul 6, 2021)
b7b82cb  Merge remote-tracking branch 'origin/master' into dht-over-p2p (borzunov, Jul 6, 2021)
5cffa06  Add docstring for Servicer (borzunov, Jul 6, 2021)
dec4f4c  Convert benchmark_averaging.py (borzunov, Jul 6, 2021)
48dad87  Merge remote-tracking branch 'origin/master' into dht-over-p2p (borzunov, Jul 6, 2021)
73c5369  Convert examples/albert, fix bugs, clarify docs (borzunov, Jul 6, 2021)
231bcee  Remove deprecated TODO/FIXMEs in DHTProtocol.shutdown (borzunov, Jul 6, 2021)
daeeba8  Fix AuthRPCWrapper with AuthRole.SERVICER (borzunov, Jul 6, 2021)
f93946a  Join proposed values for --initial_peers with space (borzunov, Jul 6, 2021)
80dbde8  Log visible maddrs in both run_{trainer,training_monitor}.py with color (borzunov, Jul 6, 2021)
4e8deb5  Fix comments by @mryab (borzunov, Jul 8, 2021)
2cd8152  Add docstring for StubBase (borzunov, Jul 8, 2021)
626569c  Clean up Unix sockets properly (borzunov, Jul 8, 2021)
a7f57ce  Handle SIGTERM to clean up Unix sockets in test_dht_node.py (borzunov, Jul 8, 2021)
342a07b  Don't pass DHTNode **kwargs to DHTProtocol (borzunov, Jul 8, 2021)
a788251  Pass DHTNode.create **kwargs to P2P (borzunov, Jul 8, 2021)
7706183  Don't pass P2P.create **kwargs directly to the daemon CLI anymore (borzunov, Jul 8, 2021)
96a57fe  Fix readthedocs (borzunov, Jul 8, 2021)
19b4833  Fix docstrings (borzunov, Jul 8, 2021)
d41a978  Fix unused channel_options in DecentralizedAverager (borzunov, Jul 8, 2021)
5f3afeb  Fix readthedocs (borzunov, Jul 8, 2021)
a4b31b6  Fix channel_options in DecentralizedAverager (borzunov, Jul 8, 2021)
3c7b903  Fix comments by @mryab (borzunov, Jul 10, 2021)
f60681e  Refactor test_dht_node.py and related files (borzunov, Jul 10, 2021)
7435541  Fix comments by @mryab (borzunov, Jul 10, 2021)
11eaaf7  Merge remote-tracking branch 'origin/master' into dht-over-p2p (borzunov, Jul 10, 2021)
1d6cc81  Return removed ping_n_attempts=10 in run_node() (borzunov, Jul 10, 2021)
4959ca0  Remove --dht_port argument from run_server.py (borzunov, Jul 10, 2021)
1774afb  Fix comments by @mryab (borzunov, Jul 10, 2021)
c16ad2a  Remove logging from p2p_daemon_bindings.ControlClient since it is too… (borzunov, Jul 10, 2021)
c8d3863  Update examples/albert/README.md (borzunov, Jul 10, 2021)
6c0bc1a  Fix comments with IP addresses (borzunov, Jul 10, 2021)
4cff374  Fix duplicated logging run_training_monitor.py (borzunov, Jul 10, 2021)
b87fbc9  Improve logging of peer's visible maddrs (borzunov, Jul 10, 2021)
1732a19  Add minor fixes to examples/albert/README.md (borzunov, Jul 10, 2021)
960431c  Fix loglevel of RPC failures in DHTProtocol (borzunov, Jul 10, 2021)
deaaf77  Fix and sort imports in dht/__init__.py (borzunov, Jul 10, 2021)
8538e4b  Fix "a trouble" -> "any trouble" (borzunov, Jul 10, 2021)
Files changed
8 changes: 4 additions & 4 deletions benchmarks/benchmark_averaging.py
@@ -34,7 +34,9 @@ def sample_tensors(hid_size, num_layers):
 def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
                         averaging_expiration: float, request_timeout: float, round_timeout: float,
                         hid_size: int, num_layers: int, spawn_dtime: float):
-    dht_root = hivemind.DHT(listen_on=f'{LOCALHOST}:*', start=True)
+    dht_root = hivemind.DHT(start=True)
+    initial_peers = dht_root.get_visible_maddrs()
+
     num_groups = 2 ** int(round(math.log2(num_peers / target_group_size)))
     nbits = int(round(math.log2(num_groups)))
     peer_tensors = [sample_tensors(hid_size, num_layers)
@@ -45,9 +47,7 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
 
     def run_averager(index):
         nonlocal successful_steps, total_steps, lock_stats
-        dht = hivemind.DHT(listen_on=f'{LOCALHOST}:*',
-                           initial_peers=[f"{LOCALHOST}:{dht_root.port}"],
-                           start=True)
+        dht = hivemind.DHT(initial_peers=initial_peers, start=True)
         initial_bits = bin(index % num_groups)[2:].rjust(nbits, '0')
         averager = hivemind.averaging.DecentralizedAverager(
             peer_tensors[i], dht, prefix='my_tensor', initial_group_bits=initial_bits, listen_on=f"{LOCALHOST}:*",
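For context, a minimal sketch (not part of the PR) of the bootstrap pattern this hunk introduces, assuming the post-PR API used above; the follower count is arbitrary:

    import hivemind

    # The root node starts first and announces its multiaddrs.
    dht_root = hivemind.DHT(start=True)
    initial_peers = dht_root.get_visible_maddrs()  # e.g. [/ip4/127.0.0.1/tcp/.../p2p/<peer id>]

    # Followers bootstrap from those multiaddrs instead of the old IP:port endpoints.
    followers = [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(4)]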
6 changes: 3 additions & 3 deletions benchmarks/benchmark_dht.py
@@ -23,9 +23,9 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     logger.info("Creating peers...")
     peers = []
     for _ in trange(num_peers):
-        neighbors = [f'0.0.0.0:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
-        peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout,
-                            listen_on=f'0.0.0.0:*')
+        neighbors = sum([peer.get_visible_maddrs()
+                         for peer in random.sample(peers, min(initial_peers, len(peers)))], [])
+        peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout)
         peers.append(peer)
 
     store_peer, get_peer = peers[-2:]
29 changes: 20 additions & 9 deletions examples/albert/arguments.py
@@ -1,5 +1,5 @@
-from typing import Optional, List
 from dataclasses import dataclass, field
+from typing import Optional, List
 
 from transformers import TrainingArguments
 
@@ -11,11 +11,26 @@ class BaseTrainingArguments:
     )
     initial_peers: List[str] = field(
         default_factory=list,
-        metadata={"help": "One or more peers (comma-separated) that will welcome you into the collaboration"}
+        metadata={"help":
+                  "Multiaddrs of the peers that will welcome you into the existing collaboration. "
+                  "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"}
     )
-    dht_listen_on: str = field(
-        default="[::]:*",
-        metadata={"help": "Network interface used for incoming DHT communication. Default: all ipv6"}
+    use_ipfs: bool = field(
+        default=False,
+        metadata={"help":
+                  "Use IPFS to find initial_peers. If enabled, you only need to provide /p2p/XXXX part of the multiaddrs "
+                  "for the initial_peers (no need to specify a particular IPv4/IPv6 host and port)"}
+    )
+    host_maddrs: List[str] = field(
+        default_factory=lambda: ['/ip4/0.0.0.0/tcp/0', '/ip4/0.0.0.0/udp/0/quic'],
+        metadata={"help":
+                  "Multiaddrs to listen for external connections from other p2p instances. "
+                  "Defaults to all IPv4 interfaces with TCP and QUIC (over UDP) protocols: "
+                  "/ip4/0.0.0.0/tcp/0 /ip4/0.0.0.0/udp/0/quic"}
+    )
+    announce_maddrs: List[str] = field(
+        default_factory=list,
+        metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"}
     )
 
 
@@ -97,10 +112,6 @@ class CollaborationArguments(AveragerArguments, CollaborativeOptimizerArguments,
         default=600,
         metadata={"help": "Statistics will be removed if not updated in this many seconds"}
     )
-    endpoint: Optional[str] = field(
-        default=None,
-        metadata={"help": "This node's IP for inbound connections, used when running from behind a proxy"}
-    )
 
 
 @dataclass
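Because several of these help strings revolve around multiaddrs, a short sketch (not part of the PR) of how one decomposes may help; it uses the py-multiaddr package that examples/albert/utils.py imports later in this diff, and the address is a documentation placeholder:

    from multiaddr import Multiaddr

    # A full peer address is transport info plus a libp2p peer ID
    # (the /p2p/... suffix, shown as XXXX in the help text above).
    maddr = Multiaddr('/ip4/203.0.113.1/tcp/31337')
    print(maddr.value_for_protocol('ip4'))  # 203.0.113.1
    print(maddr.value_for_protocol('tcp'))  # 31337

    # The host_maddrs default binds all IPv4 interfaces with OS-assigned
    # ports for both TCP and QUIC: /ip4/0.0.0.0/tcp/0, /ip4/0.0.0.0/udp/0/quic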
19 changes: 11 additions & 8 deletions examples/albert/run_trainer.py
@@ -18,8 +18,8 @@
 from torch_optimizer import Lamb
 
 import hivemind
+import utils
 from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments
-import metrics_utils
 
 
 logger = logging.getLogger(__name__)
@@ -130,7 +130,7 @@ def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
                 self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
                 self.total_samples_processed += self.samples
                 samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
-                statistics = metrics_utils.LocalMetrics(
+                statistics = utils.LocalMetrics(
                     step=self.collaborative_optimizer.local_step,
                     samples_per_second=samples_per_second,
                     samples_accumulated=self.samples,
@@ -219,13 +219,16 @@ def main():
 
     opt, scheduler = get_optimizer_and_scheduler(training_args, model)
 
-    validators, local_public_key = metrics_utils.make_validators(
+    validators, local_public_key = utils.make_validators(
         collaboration_args_dict['experiment_prefix'])
-    dht = hivemind.DHT(
-        start=True, initial_peers=collaboration_args_dict.pop('initial_peers'),
-        listen=not collaboration_args_dict['client_mode'],
-        listen_on=collaboration_args_dict.pop('dht_listen_on'),
-        endpoint=collaboration_args_dict.pop('endpoint'), record_validators=validators)
+    dht = hivemind.DHT(start=True,
+                       initial_peers=collaboration_args_dict.pop('initial_peers'),
+                       listen=not collaboration_args_dict['client_mode'],
+                       record_validators=validators,
+                       use_ipfs=collaboration_args_dict.pop('use_ipfs'),
+                       host_maddrs=collaboration_args_dict.pop('host_maddrs'),
+                       announce_maddrs=collaboration_args_dict.pop('announce_maddrs'))
+    utils.log_visible_maddrs(dht.get_visible_maddrs())
 
     total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
     statistics_expiration = collaboration_args_dict.pop('statistics_expiration')
44 changes: 24 additions & 20 deletions examples/albert/run_training_monitor.py
@@ -1,21 +1,22 @@
 #!/usr/bin/env python
 
-from dataclasses import dataclass, field, asdict
 import subprocess
 import time
+from dataclasses import asdict, dataclass, field
+from ipaddress import ip_address
 from typing import Optional
 
 import torch
+import wandb
 from torch_optimizer import Lamb
 from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
-import wandb
-from whatsmyip.providers import GoogleDnsProvider
 from whatsmyip.ip import get_ip
+from whatsmyip.providers import GoogleDnsProvider
 
-from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
 import hivemind
+import utils
+from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
 from hivemind.utils.logging import get_logger
-import metrics_utils
 
 logger = get_logger(__name__)

@@ -27,10 +28,9 @@ class CoordinatorArguments(BaseTrainingArguments):
     new workers still can join the collaboration via alive initial peers' addresses.
     Specify initial_peers argument for that purpose
     """
-    address: Optional[str] = field(
-        default=None,
-        metadata={"help": "This machine's network address. Use public IP for global experiments, "
-                          "local address for private runs"}
+    use_google_dns: bool = field(
+        default=False,
+        metadata={"help": "Use Google DNS to determine our public IP address (and add it to --announce_maddrs)"}
     )
     refresh_period: float = field(
         default=30,
@@ -139,17 +139,21 @@ def upload_checkpoint(self, current_loss):
     parser = HfArgumentParser((CoordinatorArguments, CollaborativeOptimizerArguments, AveragerArguments))
     coordinator_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()
 
-    if coordinator_args.address is None:
-        logger.warning("No address specified. Attempting to infer address from DNS.")
-        coordinator_args.address = get_ip(GoogleDnsProvider)
+    if coordinator_args.use_google_dns:
+        address = get_ip(GoogleDnsProvider)
+        logger.info(f"Google DNS responds that our IP address is {address}")
+        version = ip_address(address).version
+        coordinator_args.announce_maddrs += [f'/ip{version}/{address}/tcp/0', f'/ip{version}/{address}/udp/0/quic']
 
     experiment_prefix = coordinator_args.experiment_prefix
-    validators, local_public_key = metrics_utils.make_validators(experiment_prefix)
-    dht = hivemind.DHT(start=True, listen_on=coordinator_args.dht_listen_on,
-                       endpoint=f"{coordinator_args.address}:*", initial_peers=coordinator_args.initial_peers,
-                       record_validators=validators)
-
-    logger.info(f"Running DHT root at {coordinator_args.address}:{dht.port}")
+    validators, local_public_key = utils.make_validators(experiment_prefix)
+    dht = hivemind.DHT(start=True,
+                       initial_peers=coordinator_args.initial_peers,
+                       record_validators=validators,
+                       use_ipfs=coordinator_args.use_ipfs,
+                       host_maddrs=coordinator_args.host_maddrs,
+                       announce_maddrs=coordinator_args.announce_maddrs)
+    utils.log_visible_maddrs(dht.get_visible_maddrs())
 
     if coordinator_args.wandb_project is not None:
         wandb.init(project=coordinator_args.wandb_project)

Review thread on the new logger.info(f"Google DNS responds that our IP address is {address}") line:

mryab (Member) suggested:
-        logger.info(f"Google DNS responds that our IP address is {address}")
+        logger.info(f"Received IP address from Google DNS: {address}")

borzunov (Author): I would argue that the proposed message is more misleading (it looks like the DNS tells us the IP address for some domain, since that is what DNS usually does). However, I don't mind changing the message to another, clearer one.

mryab (Member): What about "Received IP address of monitor from Google DNS: {address}"?

borzunov (Author, Jul 10, 2021): I'd say "monitor" is an ambiguous term. I've changed it to "Received public IP address of this machine from Google DNS".
@@ -162,7 +166,7 @@ def upload_checkpoint(self, current_loss):
         metrics_dict = dht.get(experiment_prefix + '_metrics', latest=True)
         if metrics_dict is not None:
             metrics_dict = metrics_dict.value
-            metrics = [metrics_utils.LocalMetrics.parse_obj(metrics_dict[peer].value)
+            metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value)
                        for peer in metrics_dict]
             latest_step = max(item.step for item in metrics)
             if latest_step != current_step:
@@ -184,6 +188,7 @@ def upload_checkpoint(self, current_loss):
                     num_samples += item.samples_accumulated
                     sum_mini_steps += item.mini_steps
                 current_loss = sum_loss / sum_mini_steps
+                logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
 
                 if coordinator_args.wandb_project is not None:
                     wandb.log({
@@ -198,6 +203,5 @@ def upload_checkpoint(self, current_loss):
                 checkpoint_handler.save_state(current_step)
                 if checkpoint_handler.is_time_to_upload():
                     checkpoint_handler.upload_checkpoint(current_loss)
-        logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
         logger.debug("Peer is still alive...")
         time.sleep(coordinator_args.refresh_period)
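To make the new --use_google_dns branch concrete, here is a small sketch (not part of the PR) of what it computes, assuming a placeholder public IP:

    from ipaddress import ip_address

    address = '203.0.113.7'  # placeholder for get_ip(GoogleDnsProvider)
    version = ip_address(address).version  # 4
    announce_maddrs = [f'/ip{version}/{address}/tcp/0', f'/ip{version}/{address}/udp/0/quic']
    # -> ['/ip4/203.0.113.7/tcp/0', '/ip4/203.0.113.7/udp/0/quic']
    # Port 0 lets the p2p daemon pick free ports for TCP and QUIC.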
20 changes: 19 additions & 1 deletion examples/albert/metrics_utils.py → examples/albert/utils.py
@@ -1,9 +1,15 @@
 from typing import Dict, List, Tuple
 
+from multiaddr import Multiaddr
+from pydantic import BaseModel, StrictFloat, confloat, conint
+
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import RecordValidatorBase
-from pydantic import BaseModel, StrictFloat, confloat, conint
+from hivemind.utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 class LocalMetrics(BaseModel):
@@ -23,3 +29,15 @@ def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase],
     validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix),
                   signature_validator]
     return validators, signature_validator.local_public_key
+
+
+class TextStyle:
+    BOLD = '\033[1m'
+    BLUE = '\033[34m'
+    RESET = '\033[0m'
+
+
+def log_visible_maddrs(visible_maddrs: List[Multiaddr]) -> None:
+    initial_peers_str = ' '.join(str(addr) for addr in visible_maddrs)
+    logger.info(f"Running a DHT node. To connect, supply "
+                f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}")
26 changes: 23 additions & 3 deletions hivemind/averaging/averager.py
@@ -12,6 +12,7 @@
 import weakref
 from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import asdict
+from ipaddress import ip_address
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
 import grpc
@@ -30,6 +31,7 @@
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, split_for_streaming, combine_from_streaming
+from hivemind.utils.networking import choose_ip_address, strip_port
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
@@ -68,6 +70,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
       if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
     :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
+    :param announced_host: visible IP address the averager will announce for external connections from other peers.
+      If None, the address will be chosen from p2p.get_visible_maddrs() (global IPv4 addresses are preferred)
     :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
       see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
     :param kwargs: extra parameters forwarded to grpc.aio.server
@@ -102,7 +106,8 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start:
                  throughput: Optional[float] = None, min_vector_size: int = 0,
                  auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
                  listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
-                 channel_options: Optional[Sequence[Tuple[str, Any]]] = None,
+                 announced_host: Optional[str] = None,
+                 channel_options: Sequence[Tuple[str, Any]] = (),
                  shutdown_timeout: float = 5, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
         assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
@@ -122,6 +127,9 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start:
         else:
             self.mode = AveragingMode.NODE
 
+        if announced_host is None:
+            announced_host = self._choose_announced_host()
+        self.announced_host = announced_host
         self.channel_options = channel_options
         self.daemon = daemon
 
@@ -163,6 +171,17 @@ def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start:
         if start:
             self.run_in_background(await_ready=True)
 
+    def _choose_announced_host(self) -> Hostname:
+        announced_host = strip_port(self.listen_on).strip('[]')  # Stripping square brackets for IPv6
+        if ip_address(announced_host) not in [ip_address('0.0.0.0'), ip_address('::')]:
+            return announced_host
+
+        maddrs = self.dht.get_visible_maddrs()
+        announced_host = choose_ip_address(maddrs)
+        logger.info(f'Choosing IP {announced_host} as endpoint for DecentralizedAverager '
+                    f'from visible multiaddrs {maddrs}')
+        return announced_host
+
     @property
     def port(self) -> Optional[Port]:
         return self._port.value if self._port.value != 0 else None

Review thread on the logger.info(...) call in _choose_announced_host:

mryab (Member) suggested lowering the log level:
-        logger.info(f'Choosing IP {announced_host} as endpoint for DecentralizedAverager '
+        logger.debug(f'Choosing IP {announced_host} as endpoint for DecentralizedAverager '
(this is something that the user will rarely want to directly connect to, since we're mainly interested in DHT multiaddrs)

borzunov (Author, Jul 8, 2021): While the DHT works with a set of multiaddrs for one peer, we need to select only one IP for the averager to announce here (until we convert the averager itself to libp2p, which is outside of the scope of this PR). By default, this code chooses a public IPv4 address (if available) among all IPs in the visible multiaddrs. However, I'm afraid that this is not always the right choice. If it is not, a user seeing this message will get a chance to spot that the address is wrong and set the announced averager endpoint manually via the announced_host parameter. In any case, this message will be removed soon when the averager itself is converted to work over libp2p entirely (that is, with several multiaddrs for a peer).

mryab (Member): Got it, though I'm generally against logging too much info in case everything proceeds as normal. If the user does run into errors, isn't it a logical next step to switch to debug and see what goes wrong?

A maintainer (Member): As a potential middle-ground, we could replicate the behavior from hivemind.optim:

    logger.log(f'Choosing IP {announced_host} as endpoint for DecentralizedAverager',
               level=INFO if self.verbose else DEBUG)

(I have no preferences here, let @borzunov decide.)

borzunov (Author, Jul 10, 2021): I would leave it as logging.info() for a short transition period when the DHT already uses libp2p but the averager does not use it yet. After this period, this line will be removed.
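Note that choose_ip_address is imported from hivemind.utils.networking, but its body is not part of this diff. Given the review discussion above (global IPv4 addresses preferred), a hypothetical sketch of such a helper could look like the following; the function name, signature, and fallback order are assumptions for illustration, not the actual implementation:

    from ipaddress import ip_address
    from typing import Sequence

    from multiaddr import Multiaddr


    def choose_ip_address_sketch(maddrs: Sequence[Multiaddr]) -> str:
        # Collect every IPv4/IPv6 host that appears in the visible multiaddrs.
        hosts = []
        for maddr in maddrs:
            for proto in ('ip4', 'ip6'):
                try:
                    hosts.append(maddr.value_for_protocol(proto))
                except Exception:  # this multiaddr does not carry that protocol
                    continue
        # Prefer a publicly routable IPv4 address, then any public address,
        # then fall back to whatever is available (e.g. 127.0.0.1).
        for accept in (lambda ip: ip.version == 4 and ip.is_global,
                       lambda ip: ip.is_global,
                       lambda ip: True):
            for host in hosts:
                if accept(ip_address(host)):
                    return host
        raise ValueError('No IP addresses found in the given multiaddrs')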
@@ -183,7 +202,7 @@ def allow_state_sharing(self, value: bool):
     def endpoint(self) -> Optional[Endpoint]:
         if self.listen and self._averager_endpoint is None:
             assert self.port is not None, "Averager is not running yet"
-            self._averager_endpoint = f"{self.dht.get_visible_address()}:{self.port}"
+            self._averager_endpoint = f"{self.announced_host}:{self.port}"
             logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
         return self._averager_endpoint

@@ -499,7 +518,8 @@ async def _load_state_from_peers(self, future: MPFuture):
                     logger.info(f"Downloading parameters from peer {peer}")
                     stream = None
                     try:
-                        stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+                        stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True,
+                                                     options=self.channel_options)
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
                         async for message in stream:
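A usage sketch (not part of the PR) showing how channel_options now reaches the gRPC stubs used for state download; the tensor, prefix, and group size are illustrative placeholders, and the option key is the grpc.enable_retries example from the docstring above:

    import torch

    import hivemind

    dht = hivemind.DHT(start=True)
    averager = hivemind.averaging.DecentralizedAverager(
        averaged_tensors=[torch.zeros(8)], dht=dht, prefix='my_tensor',
        target_group_size=2, start=True,
        channel_options=(('grpc.enable_retries', 0),))  # forwarded to ChannelCache.get_stub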