Use logging in benchmarks, fix libp2p-related issues #280

Merged: 11 commits, Jun 17, 2021
36 changes: 23 additions & 13 deletions benchmarks/benchmark_averaging.py
@@ -6,10 +6,13 @@
import torch

import hivemind
-from hivemind.utils import LOCALHOST, increase_file_limit
+from hivemind.utils import LOCALHOST, increase_file_limit, get_logger
from hivemind.proto import runtime_pb2


+logger = get_logger(__name__)


def sample_tensors(hid_size, num_layers):
tensors = []
for i in range(num_layers):
@@ -38,8 +41,11 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
peer_tensors = [sample_tensors(hid_size, num_layers)
for _ in range(num_peers)]
processes = {dht_root}
+    lock_stats = threading.Lock()
+    successful_steps = total_steps = 0

def run_averager(index):
+        nonlocal successful_steps, total_steps, lock_stats
dht = hivemind.DHT(listen_on=f'{LOCALHOST}:*',
initial_peers=[f"{LOCALHOST}:{dht_root.port}"],
start=True)
@@ -50,11 +56,17 @@ def run_averager(index):
averaging_expiration=averaging_expiration, request_timeout=request_timeout, start=True)
processes.update({dht, averager})

-        print(end=f'<started {index}>\n', flush=True)
-        for _ in range(num_rounds):
-            success = averager.step(timeout=round_timeout)
-            print(end=('+' if success else '-'), flush=True)
-        print(end=f'<finished {index}>\n', flush=True)
+        logger.info(f'Averager {index}: started on endpoint {averager.endpoint}, group_bits: {averager.get_group_bits()}')
+        for step in range(num_rounds):
+            try:
+                success = averager.step(timeout=round_timeout) is not None
+            except:
+                success = False
+            with lock_stats:
+                successful_steps += int(success)
+                total_steps += 1
+            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step {step}")
+        logger.info(f"Averager {index}: done.")

threads = []
for i in range(num_peers):
@@ -67,10 +79,8 @@ def run_averager(index):
for thread in threads:
thread.join()

print(f"\ntest run took {time.time() - t:.3f} seconds")

for process in processes:
process.terminate()
logger.info(f"Benchmark finished in {time.time() - t:.3f} seconds.")
logger.info(f"Success rate: {successful_steps / total_steps} ({successful_steps} out of {total_steps} attempts)")


if __name__ == "__main__":
@@ -80,9 +90,9 @@ def run_averager(index):
parser.add_argument('--num_rounds', type=int, default=5, required=False)
parser.add_argument('--hid_size', type=int, default=256, required=False)
parser.add_argument('--num_layers', type=int, default=3, required=False)
-    parser.add_argument('--averaging_expiration', type=float, default=15, required=False)
-    parser.add_argument('--round_timeout', type=float, default=30, required=False)
-    parser.add_argument('--request_timeout', type=float, default=3, required=False)
+    parser.add_argument('--averaging_expiration', type=float, default=5, required=False)
+    parser.add_argument('--round_timeout', type=float, default=15, required=False)
+    parser.add_argument('--request_timeout', type=float, default=1, required=False)
parser.add_argument('--spawn_dtime', type=float, default=0.1, required=False)
parser.add_argument('--increase_file_limit', action="store_true")
args = vars(parser.parse_args())
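Taken together, the changes in this file swap `print` for hivemind's module logger and add thread-safe success counters. A minimal, self-contained sketch of that combined pattern (the `record_step` helper is illustrative, not part of the diff):

```python
import threading

from hivemind.utils import get_logger  # same import the diff adds above

logger = get_logger(__name__)

successful_steps = total_steps = 0
lock_stats = threading.Lock()

def record_step(success: bool) -> None:
    # thread-safe bookkeeping, as in run_averager above
    global successful_steps, total_steps
    with lock_stats:
        successful_steps += int(success)
        total_steps += 1

record_step(True)
record_step(False)
logger.info(f"Success rate: {successful_steps / total_steps} ({successful_steps} out of {total_steps} attempts)")
```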
16 changes: 8 additions & 8 deletions benchmarks/benchmark_dht.py
@@ -20,7 +20,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
wait_after_request: float, wait_before_read: float, wait_timeout: float, expiration: float):
random.seed(random_seed)

print("Creating peers...")
logger.info("Creating peers...")
peers = []
for _ in trange(num_peers):
neighbors = [f'0.0.0.0:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
@@ -32,10 +32,10 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b

expert_uids = list(set(f"expert.{random.randint(0, 999)}.{random.randint(0, 999)}.{random.randint(0, 999)}"
for _ in range(num_experts)))
print(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
logger.info(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
random.shuffle(expert_uids)

print(f"Storing experts to dht in batches of {expert_batch_size}...")
logger.info(f"Storing experts to dht in batches of {expert_batch_size}...")
successful_stores = total_stores = total_store_time = 0
benchmark_started = time.perf_counter()
endpoints = []
@@ -52,8 +52,8 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
successful_stores += sum(successes)
time.sleep(wait_after_request)

print(f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})")
print(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
logger.info(f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})")
logger.info(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
time.sleep(wait_before_read)

if time.perf_counter() - benchmark_started > expiration:
@@ -74,11 +74,11 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
if time.perf_counter() - benchmark_started > expiration:
logger.warning("keys expired midway during get requests. If that isn't desired, increase expiration_time param")

print(f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})")
print(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")
logger.info(f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})")
logger.info(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")

alive_peers = [peer.is_alive() for peer in peers]
print(f"Node survival rate: {len(alive_peers) / len(peers) * 100:.3f}%")
logger.info(f"Node survival rate: {len(alive_peers) / len(peers) * 100:.3f}%")


if __name__ == "__main__":
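For context, a single store/get round trip against one local node looks roughly like the sketch below; it assumes `DHT.store`/`DHT.get` and `get_dht_time` behave as in hivemind's DHT API of this version, so treat it as illustrative:

```python
import hivemind
from hivemind.utils import get_dht_time

node = hivemind.DHT(listen_on='127.0.0.1:*', start=True)

# the benchmark issues thousands of these store/get pairs in batches
assert node.store('expert.0.1.2', value='127.0.0.1:1337',
                  expiration_time=get_dht_time() + 30)
print(node.get('expert.0.1.2'))
node.shutdown()
```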
6 changes: 5 additions & 1 deletion benchmarks/benchmark_tensor_compression.py
@@ -5,6 +5,10 @@

from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.logging import get_logger


+logger = get_logger(__name__)


def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
@@ -29,4 +33,4 @@ def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionTyp
for i in range(args.num_iters):
tm += benchmark_compression(X, compression_type)
tm /= args.num_iters
print(f"Compression type: {name}, time: {tm}")
logger.info(f"Compression type: {name}, time: {tm}")
36 changes: 20 additions & 16 deletions benchmarks/benchmark_throughput.py
@@ -10,19 +10,23 @@
from hivemind import find_open_port
from hivemind.server import layers
from hivemind.utils.threading import increase_file_limit
+from hivemind.utils.logging import get_logger


+logger = get_logger(__name__)


def print_device_info(device=None):
"""Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))
-    print('Using device:', device)
+    logger.info(f'Using device: {device}')

# Additional Info when using cuda
if device.type == 'cuda':
-        print(torch.cuda.get_device_name(0))
-        print('Memory Usage:')
-        print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
-        print('Cached: ', round(torch.cuda.memory_cached(0) / 1024 ** 3, 1), 'GB')
+        logger.info(torch.cuda.get_device_name(0))
+        logger.info(f'Memory Usage:')
+        logger.info(f'Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB')
+        logger.info(f'Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB')


def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
@@ -111,25 +115,25 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
abs(timestamps[key2] - timestamps[key1]) if (key1 in timestamps and key2 in timestamps) else float('nan')
total_examples = batch_size * num_clients * num_batches_per_client
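Side note: `time_between`, defined just above from the `timestamps` dict, falls back to NaN when either event was never recorded, so a partially failed run still produces a complete report. A standalone sketch:

```python
import math

timestamps = {'began_launching_server': 0.0, 'server_ready': 4.2}  # sample data

def time_between(key1: str, key2: str) -> float:
    # NaN-safe difference between two recorded events
    return abs(timestamps[key2] - timestamps[key1]) \
        if (key1 in timestamps and key2 in timestamps) else float('nan')

print(time_between('began_launching_server', 'server_ready'))        # 4.2
print(math.isnan(time_between('server_ready', 'clients_finished')))  # True
```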

-    print('\n' * 3)
-    print("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
-    print(f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, max_batch_size={max_batch_size},"
+    logger.info("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
+    logger.info(f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, max_batch_size={max_batch_size},"
                f" expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}")
-    print(f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
+    logger.info(f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
                f"batch_size={batch_size}, backprop={backprop}")

-    print("Results: ")
-    print(f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
+    logger.info("Results: ")
+    logger.info(f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
                f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
                f"{time_between('created_experts', 'server_ready') :.3f} s. networking)")
-    print(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
-    print(f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
+    logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
+    logger.info(f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
                f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s.")
-    print(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
+    logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
    if benchmarking_failed.is_set():
-        print("Note: benchmark code failed, timing/memory results only indicate time till failure!")
+        logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")
    print_device_info(device)
-    print(flush=True)
+    sys.stdout.flush()
+    sys.stderr.flush()

assert not benchmarking_failed.is_set()

30 changes: 18 additions & 12 deletions examples/albert/README.md
@@ -12,21 +12,24 @@ This tutorial will walk you through the steps to set up collaborative training w
## Running an experiment
- Run the first DHT peer to welcome trainers and record training statistics (e.g. loss, performance):
- In this example, we use [wandb.ai](https://wandb.ai/site) to plot training metrics; If you're unfamiliar with Weights & Biases, here's a [quickstart tutorial](https://docs.wandb.ai/quickstart).
-- Run `python run_first_peer.py --listen_on '[::]:*' --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
+- Run `python run_first_peer.py --dht_listen_on '[::]:*' --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
- `NAME_YOUR_EXPERIMENT` must be a unique name of this training run, e.g. `my-first-albert`. It cannot contain `.` due to naming conventions.
- `WANDB_PROJECT_HERE` is a name of wandb project used to track training metrics. Multiple experiments can have the same project name.
- This peer will run a DHT node on a certain IP/port (`Running DHT root at ...`). You will need this address for next steps
```
-+ python ./run_first_peer.py --listen_on '[::]:31209' --experiment_prefix ysda_albert_v10 --wandb_project Demo-run
-[2021/04/19 02:30:06.051][WARN][root.<module>:36] No address specified. Attempting to infer address from DNS.
-[2021/04/19 02:30:06.088][INFO][root.<module>:44] Running DHT root at 18.217.13.97:31209
-wandb: Currently logged in as: ??? (use `wandb login --relogin` to force relogin)
-wandb: Tracking run with wandb version 0.10.26
-wandb: Syncing run wandering-sky-58
-wandb: View project at https://wandb.ai/yhn112/Demo-run
-wandb: 🚀 View run at https://wandb.ai/yhn112/Demo-run/runs/38ygvt3n
-wandb: Run data is saved locally in /home/hivemind/examples/albert/wandb/run-20210419_023006-38ygvt3n
++ python run_first_peer.py --dht_listen_on '[::]:*' --experiment_prefix my-albert-v1 --wandb_project Demo-run
+[2021/06/17 16:26:35.931][WARN][root.<module>:140] No address specified. Attempting to infer address from DNS.
+[2021/06/17 16:26:36.083][INFO][root.<module>:149] Running DHT root at 193.106.95.184:38319
+wandb: Currently logged in as: XXX (use `wandb login --relogin` to force relogin)
+wandb: Tracking run with wandb version 0.10.32
+wandb: Syncing run dry-mountain-2
+wandb: View project at https://wandb.ai/XXX/Demo-run
+wandb: View run at https://wandb.ai/XXX/Demo-run/runs/YYY
+wandb: Run data is saved locally in /path/to/run/data
wandb: Run `wandb offline` to turn off syncing.
[2021/04/19 02:26:41.064][INFO][optim.collaborative.fetch_collaboration_state:323] Found no active peers: None
[2021/04/19 02:26:44.068][INFO][optim.collaborative.fetch_collaboration_state:323] Found no active peers: None
...
[2021/04/19 02:37:37.246][INFO][root.<module>:74] 11.05164
[2021/04/19 02:39:37.441][INFO][root.<module>:74] 11.03771
[2021/04/19 02:40:37.541][INFO][root.<module>:74] 11.02886
```
@@ -37,19 +40,22 @@
- if necessary, specify paths: `--dataset_path ./path/to/unpacked/data --tokenizer ./path/to/tokenizer/config` (see [default paths](https://github.com/learning-at-home/hivemind/blob/collaborative_albert_example/examples/albert/run_trainer.py#L63-L69) for reference)
- run:
```shell
-CUDA_VISIBLE_DEVICES=0 HIVEMIND_THREADS=64 python ./hivemind/examples/albert/run_trainer.py \
+HIVEMIND_THREADS=64 python run_trainer.py \
--experiment_prefix SAME_AS_IN_RUN_FIRST_PEER --initial_peers ONE_OR_MORE_PEERS --seed 42 \
--logging_first_step --logging_steps 100 --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
```
Here, `ONE_OR_MORE_PEERS` stands for either your coordinator endpoint (e.g. `123.123.123.123:1337`), an endpoint of any pre-existing trainer or multiple endpoints for stability. See tips & tricks section below for more information on setting up collaborative training.
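For instance, a trainer-side DHT that joins through several peers at once could be constructed as below (endpoints are placeholders; the constructor arguments mirror those used in the benchmarks of this PR):

```python
import hivemind

# placeholder endpoints: the coordinator plus one already-running trainer
initial_peers = ["123.123.123.123:1337", "123.123.123.124:31337"]
dht = hivemind.DHT(listen_on="[::]:*", initial_peers=initial_peers, start=True)
```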

As the peer begins training, it will periodically report training logs in the following form:
```
-{'loss': 4.3577, 'learning_rate': 0.001318944, 'epoch': 0.0}
[...][INFO][...] Collaboration accumulated 448 samples from 17 peers; ETA 18.88 seconds (refresh in 15.73s.)
[...][INFO][...] Collaboration accumulated 4096 samples from 16 peers; ETA 0.00 seconds (refresh in 0.50s.)
[...][INFO][optim.collaborative.step:195] Averaged tensors successfully with 17 peers
[...][INFO][optim.collaborative.step:211] Optimizer step: done!
+06/17/2021 18:58:23 - INFO - __main__ - Step 0
+06/17/2021 18:58:23 - INFO - __main__ - Your current contribution: 892 samples
+06/17/2021 18:58:23 - INFO - __main__ - Local loss: 11.023
```

__Sanity check:__ a healthy peer will periodically report `Averaged tensors successfully with [N > 1]` peers.
10 changes: 6 additions & 4 deletions examples/albert/tokenize_wikitext103.py
@@ -1,7 +1,6 @@
#!/usr/bin/env python
""" This script builds a pre-tokenized compressed representation of wikitext103 using huggingface/datasets """
import random
-from collections import defaultdict
from functools import partial
from multiprocessing import cpu_count

@@ -10,6 +9,9 @@
from transformers import AlbertTokenizerFast


+COLUMN_NAMES = ('attention_mask', 'input_ids', 'sentence_order_label', 'special_tokens_mask', 'token_type_ids')


def create_instances_from_document(tokenizer, document, max_seq_length):
"""Creates `TrainingInstance`s for a single document."""
# We DON'T just concatenate all of the tokens from a document into a long
@@ -76,14 +78,14 @@ def tokenize_function(tokenizer, examples):
# Remove empty texts
texts = (text for text in examples["text"] if len(text) > 0 and not text.isspace())

-    new_examples = defaultdict(list)
+    new_examples = {col: [] for col in COLUMN_NAMES}

for text in texts:
instances = create_instances_from_document(tokenizer, text, max_seq_length=512)
for instance in instances:
for key, value in instance.items():
new_examples[key].append(value)

return new_examples
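A likely motivation for the fixed `COLUMN_NAMES` (an inference, the diff itself does not say): `datasets.map(batched=True)` expects every batch to return the same keys, and with `defaultdict(list)` a batch whose texts produce no training instances would return no keys at all. A minimal illustration:

```python
COLUMN_NAMES = ('attention_mask', 'input_ids', 'sentence_order_label',
                'special_tokens_mask', 'token_type_ids')

def make_batch(instances):
    new_examples = {col: [] for col in COLUMN_NAMES}
    for instance in instances:
        for key, value in instance.items():
            new_examples[key].append(value)
    return new_examples

print(sorted(make_batch([]).keys()))  # all five columns, even for an empty batch
```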


@@ -96,7 +98,7 @@ def tokenize_function(tokenizer, examples):
tokenized_datasets = wikitext.map(
partial(tokenize_function, tokenizer),
batched=True,
-        num_proc=cpu_count(),
+        num_proc=8,
remove_columns=["text"],
)

7 changes: 6 additions & 1 deletion hivemind/client/averaging/__init__.py
@@ -554,7 +554,12 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
:param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
"""
while True:
-        trigger, future = pipe.recv()
+        try:
+            trigger, future = pipe.recv()
+        except BaseException as e:
+            logger.debug(f"Averager background thread finished: {repr(e)}")
+            break

if trigger == '_SHUTDOWN':
break
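This guard lets the state-fetching thread exit quietly when its pipe is torn down (for example during interpreter shutdown) instead of dying on an unhandled `EOFError`/`OSError`. A self-contained sketch of the same pattern:

```python
import multiprocessing as mp
import threading

def background_worker(pipe) -> None:
    while True:
        try:
            trigger = pipe.recv()
        except BaseException as e:  # a closed pipe ends the thread gracefully
            print(f"background thread finished: {repr(e)}")
            break
        if trigger == '_SHUTDOWN':
            break

parent_end, child_end = mp.Pipe()
thread = threading.Thread(target=background_worker, args=(child_end,), daemon=True)
thread.start()
parent_end.close()  # recv() in the worker now raises EOFError, ending the loop
thread.join()
```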

2 changes: 1 addition & 1 deletion hivemind/server/runtime.py
@@ -118,7 +118,7 @@ def iterate_minibatches_from_pools(self, timeout=None):
with DefaultSelector() as selector:
for pool in self.pools:
selector.register(pool.batch_receiver, EVENT_READ, pool)
-            # selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
+            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)

while True:
# wait until at least one batch_receiver becomes available
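Re-enabling this registration means the minibatch loop can be woken both by batch receivers and by a dedicated shutdown pipe. A hedged, standalone sketch of the mechanism (names are illustrative; POSIX only, since the selector needs `fileno()`-capable objects):

```python
from multiprocessing import Pipe
from selectors import DefaultSelector, EVENT_READ

SHUTDOWN_TRIGGER = "shutdown"
shutdown_send, shutdown_recv = Pipe()

with DefaultSelector() as selector:
    # runtime.py also registers each pool.batch_receiver connection here
    selector.register(shutdown_recv, EVENT_READ, SHUTDOWN_TRIGGER)
    shutdown_send.send(None)  # simulate a shutdown request
    for key, _ in selector.select():
        if key.data == SHUTDOWN_TRIGGER:
            print("shutdown requested, leaving the loop")
```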
2 changes: 1 addition & 1 deletion setup.py
@@ -76,7 +76,7 @@ def libp2p_build_install():
def libp2p_download_install():
install_path = os.path.join(here, 'hivemind', 'hivemind_cli')
binary_path = os.path.join(install_path, 'p2pd')
-    if 'p2pd' not in os.listdir(install_path) or md5(binary_path) != P2PD_CHECKSUM:
+    if not os.path.exists(binary_path) or md5(binary_path) != P2PD_CHECKSUM:
print('Downloading Peer to Peer Daemon')
url = f'https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd'
urllib.request.urlretrieve(url, binary_path)
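The new condition checks the binary path directly rather than scanning the directory listing. A self-contained sketch of the download-if-missing-or-stale pattern, with a hypothetical `md5` file-hash helper standing in for the one in setup.py:

```python
import hashlib
import os
import urllib.request

def md5(path: str, chunk_size: int = 4096) -> str:
    # hypothetical helper: md5 digest of a file's contents
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()

def ensure_p2pd(binary_path: str, url: str, expected_checksum: str) -> None:
    # re-download only when the binary is absent or its checksum changed
    if not os.path.exists(binary_path) or md5(binary_path) != expected_checksum:
        print('Downloading Peer to Peer Daemon')
        urllib.request.urlretrieve(url, binary_path)
```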
4 changes: 2 additions & 2 deletions tests/test_p2p_daemon.py
@@ -430,9 +430,9 @@ def handler(arg, key):
await server_replica2.stop_listening()

# Primary does not handle replicas protocols
-    with pytest.raises(asyncio.IncompleteReadError):
+    with pytest.raises(Exception):
        await client.call_peer_handler(server_id, handler_name + '1', b'')
-    with pytest.raises(asyncio.IncompleteReadError):
+    with pytest.raises(Exception):
        await client.call_peer_handler(server_id, handler_name + '2', b'')

await server_primary.stop_listening()