Minor style updates in examples (#321)
Co-authored-by: Max Ryabinin <[email protected]>
justheuristic and mryab authored Jul 15, 2021
1 parent 11db5fd commit def7038
Showing 5 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion examples/albert/arguments.py
@@ -102,7 +102,7 @@ class CollaborationArguments(CollaborativeOptimizerArguments, BaseTrainingArgume
         default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
     )
     backup_every_steps: int = field(
-        default=10, metadata={"help": "In case of NaN, training restore from a backup updated with this frequency."}
+        default=10, metadata={"help": "Frequency of backups to restore from in case of encountering NaN values"}
     )
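
For context on the pattern being edited: backup_every_steps is a standard dataclasses field whose metadata["help"] string becomes CLI help text when the arguments class is parsed with transformers' HfArgumentParser. A minimal sketch of that round trip (the simplified class below omits the real base classes, and the parser wiring is an assumption for illustration, not part of this diff):

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class CollaborationArguments:
    backup_every_steps: int = field(
        default=10, metadata={"help": "Frequency of backups to restore from in case of encountering NaN values"}
    )


# HfArgumentParser turns each field into a CLI flag and uses metadata["help"] as its help text
parser = HfArgumentParser(CollaborationArguments)
(args,) = parser.parse_args_into_dataclasses(["--backup_every_steps", "25"])
assert args.backup_every_steps == 25
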
14 changes: 7 additions & 7 deletions examples/albert/run_trainer.py
@@ -5,7 +5,6 @@
 import pickle
 from dataclasses import asdict
 from pathlib import Path
-from typing import Any

 import torch
 import transformers
@@ -97,8 +96,8 @@ def get_optimizer_and_scheduler(training_args, model):

 class CollaborativeCallback(transformers.TrainerCallback):
     """
-    This callback monitors and reports collaborative training progress,
-    In case of a catastrophic failure, it can also revert training to a backup
+    This callback monitors and reports collaborative training progress.
+    In case of a catastrophic failure, it can also revert training to a backup.
     """

     def __init__(
@@ -153,6 +152,7 @@ def on_step_end(
             )
             logger.info(f"Step {self.collaborative_optimizer.local_step}")
             logger.info(f"Your current contribution: {self.total_samples_processed} samples")
+            logger.info(f"Performance: {samples_per_second} samples per second.")
             if self.steps:
                 logger.info(f"Local loss: {self.loss / self.steps}")
             if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
@@ -181,16 +181,16 @@ def params_are_finite(self):
         return True

     @torch.no_grad()
-    def backup_state(self) -> Any:
+    def backup_state(self) -> bytes:
         return pickle.dumps(
-            {"model": self.model.state_dict(), "training": self.collaborative_optimizer.opt.state_dict()}
+            {"model": self.model.state_dict(), "optimizer": self.collaborative_optimizer.opt.state_dict()}
         )

     @torch.no_grad()
-    def restore_from_backup(self, backup):
+    def restore_from_backup(self, backup: bytes):
         state = pickle.loads(backup)
         self.model.load_state_dict(state["model"])
-        self.collaborative_optimizer.opt.load_state_dict(state["training"])
+        self.collaborative_optimizer.opt.load_state_dict(state["optimizer"])


 class NoOpScheduler(LRSchedulerBase):
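
A minimal standalone sketch of the backup_state / restore_from_backup round trip above, with the callback's model and collaborative optimizer replaced by a toy torch model and optimizer (everything outside the two function bodies is an assumption for illustration):

import pickle

import torch


@torch.no_grad()
def backup_state(model: torch.nn.Module, opt: torch.optim.Optimizer) -> bytes:
    # Mirrors CollaborativeCallback.backup_state: serialize both state dicts together
    return pickle.dumps({"model": model.state_dict(), "optimizer": opt.state_dict()})


@torch.no_grad()
def restore_from_backup(model: torch.nn.Module, opt: torch.optim.Optimizer, backup: bytes) -> None:
    # Mirrors CollaborativeCallback.restore_from_backup: roll both objects back
    state = pickle.loads(backup)
    model.load_state_dict(state["model"])
    opt.load_state_dict(state["optimizer"])


model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters())
backup = backup_state(model, opt)  # snapshot taken every backup_every_steps steps
# ... training steps that may produce NaN parameters ...
restore_from_backup(model, opt, backup)  # revert to the last finite snapshot
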
3 changes: 1 addition & 2 deletions examples/albert/utils.py
@@ -42,8 +42,7 @@ def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
         unique_addrs = {addr["p2p"] for addr in visible_maddrs}
         initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
     else:
-        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr]
-        available_ips += [Multiaddr(addr) for addr in visible_maddrs if "ip6" in addr]
+        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr]
         if available_ips:
             preferred_ip = choose_ip_address(available_ips)
             selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)]
1 change: 1 addition & 0 deletions hivemind/optim/simple.py
@@ -79,6 +79,7 @@ def __init__(
     def step(self, *args, **kwargs):
         with self.lock_parameters:
             loss = self.opt.step(*args, **kwargs)
+
         self.local_step += 1
         if self.local_step % self.averaging_step_period == 0:
             self.update_event.set()
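
The hunk above only inserts a blank line, but the surrounding step() logic shows the pattern this optimizer uses: run the wrapped optimizer under a lock, count local steps, and wake a background averaging thread every averaging_step_period steps. A self-contained sketch of that pattern (attribute names follow the diff; the class scaffolding is an assumption, not hivemind's actual class):

import threading

import torch


class PeriodicAveragingOptimizer:
    """Wraps a torch optimizer and signals an averager every `averaging_step_period` steps."""

    def __init__(self, opt: torch.optim.Optimizer, averaging_step_period: int = 10):
        self.opt = opt
        self.averaging_step_period = averaging_step_period
        self.local_step = 0
        self.lock_parameters = threading.Lock()  # prevents averaging mid-update
        self.update_event = threading.Event()  # a background averager waits on this

    def step(self, *args, **kwargs):
        with self.lock_parameters:
            loss = self.opt.step(*args, **kwargs)

        self.local_step += 1
        if self.local_step % self.averaging_step_period == 0:
            self.update_event.set()
        return loss
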
4 changes: 2 additions & 2 deletions hivemind/p2p/p2p_daemon.py
@@ -297,11 +297,11 @@ async def _add_protobuf_stream_handler(
         name: str,
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
         input_protobuf_type: type,
-        max_prefetch: int = 0,
+        max_prefetch: int = 5,
     ) -> None:
         """
         :param max_prefetch: Maximum number of items to prefetch from the request stream.
-            ``max_prefetch <= 0`` means unlimited (default).
+            ``max_prefetch <= 0`` means unlimited.
         :note: Since the cancel messages are sent via the input stream,
             they will not be received while the prefetch buffer is full.
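
The new default bounds how many stream items are buffered ahead of the handler instead of prefetching without limit. A rough sketch of those semantics using asyncio.Queue (the reader/queue split is illustrative only, not the daemon's actual implementation):

import asyncio
from typing import AsyncIterator


def make_prefetch_queue(max_prefetch: int = 5) -> asyncio.Queue:
    # asyncio.Queue(maxsize=0) is unbounded, matching "max_prefetch <= 0 means unlimited"
    return asyncio.Queue(maxsize=max(0, max_prefetch))


async def prefetching_reader(stream: AsyncIterator, queue: asyncio.Queue) -> None:
    # `put` blocks once `max_prefetch` items are waiting, so nothing further is read
    # from the stream -- which is also why cancel messages sent on the same stream
    # are not seen while the prefetch buffer is full.
    async for item in stream:
        await queue.put(item)
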
