Add gradient buffers to CollaborativeOptimizer #220

Merged · 9 commits · Apr 14, 2021
78 changes: 67 additions & 11 deletions hivemind/client/optim/collaborative.py
@@ -2,7 +2,7 @@
import warnings
from dataclasses import dataclass
from threading import Thread, Lock, Event
from typing import Optional, Type
from typing import Optional, Iterator
import logging

import torch
@@ -11,7 +11,7 @@
from hivemind.dht import DHT
from hivemind.client.optim.base import DecentralizedOptimizerBase
from hivemind.client.averaging.training import TrainingAverager
from hivemind.utils import get_logger, get_dht_time, run_in_background, ValueWithExpiration
from hivemind.utils import get_logger, get_dht_time, ValueWithExpiration
from hivemind.client.optim.performance_ema import PerformanceEMA

logger = get_logger(__name__)
@@ -47,15 +47,14 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):

:note: This optimizer behaves unlike regular pytorch optimizers in two ways:

- calling .step will periodially zero-out gradients w.r.t. model parameters after each step
- calling .step will periodically zero-out gradients w.r.t. model parameters after each step
- it may take multiple .step calls without updating model parameters, waiting for peers to accumulate enough samples

:param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
:param dht: a running hivemind.DHT daemon connected to other peers
:param prefix: a common prefix for all metadata stored by CollaborativeOptimizer in the DHT
:param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
:param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
:param target_group_size: maximum group size for DecentralizedAverager's all-reduce
:param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
:param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
:param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
@@ -69,6 +68,12 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
:param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
:param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
:param scheduler: if specified, use this scheduler to update optimizer learning rate
:param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
:param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
:param kwargs: additional parameters forwarded to DecentralizedAverager
:note: if you are using CollaborativeOptimizer with a lr_scheduler, it is recommended to pass this scheduler
explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
"""
@@ -78,14 +83,17 @@ def __init__(self, opt: torch.optim.Optimizer, *, dht: DHT, prefix: str, target_
min_refresh_period: float = 0.5, max_refresh_period: float = 30, default_refresh_period: float = 3,
expected_drift_peers: float = 3, expected_drift_rate: float = 0.2, performance_ema_alpha: float = 0.1,
metadata_expiration: float = 30.0, averaging_timeout: Optional[float] = None, verbose: bool = False,
**kwargs):
reuse_grad_buffers: bool = False, accumulate_grads_on: Optional[torch.device] = None, **kwargs):
super().__init__(opt, dht)
if reuse_grad_buffers and accumulate_grads_on is not None:
logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
self.prefix, self.scheduler = prefix, scheduler
self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
self.min_refresh_period, self.max_refresh_period, self.default_refresh_period =\
min_refresh_period, max_refresh_period, default_refresh_period
self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
self.status_loglevel = logging.INFO if verbose else logging.DEBUG
self.averager = self._make_averager(**kwargs)

@@ -134,9 +142,12 @@ def step(self, batch_size: Optional[int] = None, **kwargs):
:param batch_size: optional override for batch_size_per_step from init
:note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
"""
if batch_size is not None and self.batch_size_per_step is None:
raise ValueError("Please either set batch_size_per_step parameter at init or provide batch_size in .step")
batch_size = self.batch_size_per_step if batch_size is None else batch_size
if self.batch_size_per_step is None:
if batch_size is None:
raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
logger.log(self.status_loglevel, f"Setting default batch_size_per_step to {batch_size}")
self.batch_size_per_step = batch_size
batch_size = batch_size if batch_size is not None else self.batch_size_per_step

if not self.is_synchronized:
self.load_state_from_peers()
@@ -146,6 +157,7 @@ def step(self, batch_size: Optional[int] = None, **kwargs):
logger.warning(f"Training step took {get_dht_time() - self.last_step_time}, "
f"but metadata expired in {self.metadata_expiration} s.")

self.accumulate_grads_(batch_size)
with self.lock_local_progress:
self.local_samples_accumulated += batch_size
self.local_steps_accumulated += 1
@@ -164,6 +176,9 @@ def step(self, batch_size: Optional[int] = None, **kwargs):
return

with self.performance_ema.pause(), self.lock_collaboration_state:
# divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
self.apply_accumulated_grads_(scale_by=1. / self.local_steps_accumulated)

if self.collaboration_state.num_peers > 1:
mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
weight = self.local_samples_accumulated / mean_samples_per_worker
@@ -176,6 +191,7 @@ def step(self, batch_size: Optional[int] = None, **kwargs):

self.opt.step()
self.opt.zero_grad()
self.reset_accumulated_grads_()
self.local_samples_accumulated = self.local_steps_accumulated = 0
self.collaboration_state.register_step()
self.collaboration_state_updated.set()
@@ -184,6 +200,46 @@ def step(self, batch_size: Optional[int] = None, **kwargs):
logger.log(self.status_loglevel, f"Optimizer step: done!")
return output
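
Continuing the usage sketch above, a local training loop around this .step would look roughly as follows; get_batches and compute_loss are placeholder helpers, not part of this diff:

```python
# sketch of a local training loop around CollaborativeOptimizer.step
for batch in get_batches():                  # placeholder data source
    loss = compute_loss(model, batch)        # placeholder loss computation
    loss.backward()                          # fill .grad buffers for this local micro-batch
    opt.step(batch_size=len(batch))          # accumulate; the wrapped optimizer only steps once the
                                             # collaboration reaches target_batch_size samples
    if not opt.reuse_grad_buffers:
        opt.opt.zero_grad()                  # with reuse_grad_buffers=True, never zero gradients manually
```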

def _grad_buffers(self) -> Iterator[torch.Tensor]:
""" pytorch-internal gradient buffers """
for param_group in self.opt.param_groups:
for param in param_group['params']:
if param.grad is None:
yield torch.zeros_like(param)
else:
yield param.grad

@torch.no_grad()
def accumulated_grads(self) -> Iterator[torch.Tensor]:
    """ local gradient accumulators """
    if self.reuse_grad_buffers:
        yield from self._grad_buffers()
        return
    if self._grads is None:
        # lazily allocate dedicated accumulators, optionally on a separate device (accumulate_grads_on)
        self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
    yield from self._grads

@torch.no_grad()
def accumulate_grads_(self, batch_size: int):
""" add current gradients to grad accumulators (if any) """
if self.reuse_grad_buffers:
return # user is responsible for accumulating gradients in .grad buffers
alpha = float(batch_size) / self.batch_size_per_step
for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)

@torch.no_grad()
def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
grad_buf[...] = grad_acc.to(grad_buf.device)
if scale_by is not None:
grad_buf.mul_(scale_by)

@torch.no_grad()
def reset_accumulated_grads_(self):
    """ zero out the gradient accumulators after a collaborative optimizer step """
    for grad_acc in self.accumulated_grads():
        grad_acc.zero_()
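
The arithmetic behind these helpers: each micro-batch gradient is added to the accumulator with weight batch_size / batch_size_per_step, and right before the collaborative step the accumulator is divided by the number of local .step calls to recover an average gradient. A tiny self-contained illustration with plain tensors (placeholder shapes and values, not part of this diff):

```python
import torch

batch_size_per_step = 32
grad_acc = torch.zeros(4)                        # stands in for one accumulator from accumulated_grads()
local_steps = 0

for micro_batch_size in (32, 32, 16):            # three local micro-batches, the last one smaller
    grad_buf = torch.randn(4)                    # stands in for param.grad after loss.backward()
    alpha = micro_batch_size / batch_size_per_step
    grad_acc.add_(grad_buf, alpha=alpha)         # accumulate_grads_: weight by relative batch size
    local_steps += 1

# apply_accumulated_grads_(scale_by=1. / local_steps): copy back into .grad and rescale
grad_buf = grad_acc / local_steps
grad_acc.zero_()                                 # reset_accumulated_grads_ after the collaborative step
```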

def report_training_progress(self):
""" Periodically publish metadata and the current number of samples accumulated towards the next step """
while self.is_alive():
@@ -235,17 +291,17 @@ def fetch_collaboration_state(self) -> CollaborationState:
if not is_client:
global_optimizer_step = max(global_optimizer_step, opt_step)

total_samples_accumulated = estimated_curent_samples = total_samples_per_second = 0
total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0

for opt_step, samples_accumulated, samples_per_second, timestep, is_client in valid_peer_states:
total_samples_per_second += samples_per_second
if opt_step == global_optimizer_step:
total_samples_accumulated += samples_accumulated
estimated_curent_samples += samples_accumulated + max(0, current_time - timestep) * samples_per_second
estimated_current_samples += samples_accumulated + max(0, current_time - timestep) * samples_per_second
# note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
# the rationale behind this is that outdated peers will synchronize and begin contributing shortly.

estimated_samples_remaining = self.target_batch_size - estimated_curent_samples
estimated_samples_remaining = self.target_batch_size - estimated_current_samples
estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second

expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))