Implement CenteredClip in averager #379

Draft · wants to merge 2 commits into base: master

Changes from all commits
113 changes: 113 additions & 0 deletions hivemind/averaging/accumulators.py
@@ -0,0 +1,113 @@
import dataclasses
from abc import ABC
from typing import Callable, Optional

import torch


class AccumulatorBase(ABC):
    """Interface for strategies that combine tensor parts received from peers."""

    def accumulate_part(self, tensor: torch.Tensor, weight: float) -> None:
        ...

    def reduce(self) -> torch.Tensor:
        ...


# An AccumulatorFactory is called as factory(part_shape, n_peers) for each tensor part
AccumulatorFactory = Callable[[torch.Size, int], AccumulatorBase]


class MeanAccumulator(AccumulatorBase):
    def __init__(self, part_shape: torch.Size, _n_peers: int):
        self._accumulator = torch.zeros(part_shape)
        self._denominator = 0.0

    def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
        self._accumulator.add_(tensor_part, alpha=weight)
        self._denominator += weight

    def reduce(self) -> torch.Tensor:
        return self._accumulator.div_(self._denominator)


class CenteredClipAccumulator(AccumulatorBase):
    def __init__(self, part_shape: torch.Size, n_peers: int, **kwargs):
        self._kwargs = kwargs

        self._tensors = torch.empty((n_peers, *part_shape))
        self._weights = torch.empty(n_peers)
        self._index = 0

    def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
        self._tensors[self._index] = tensor_part
        self._weights[self._index] = weight
        self._index += 1

    def reduce(self) -> torch.Tensor:
        clipped = centered_clip(self._tensors, self._weights, **self._kwargs)
        return clipped.result


@dataclasses.dataclass(frozen=True)
class CenteredClipResult:
    result: torch.Tensor
    n_clipped: torch.Tensor
    last_step_delta: torch.Tensor


def centered_clip(
    input_tensors: torch.Tensor,
    weights: torch.Tensor,
    tau: float = 1.0,
    n_iters: int = 20,
    stop_delta: Optional[float] = None,
> Member: preference: let's default to some reasonable delta and a very large n steps

) -> CenteredClipResult:
"""
Optimized implementation of CenteredClip from [Karimireddy, 2021].
Intended to be used in a decentralized fashion as in [Gorbunov, 2021].

:stop_delta: Stop iterations early if the ``L_inf`` norm of the last step is less than ``stop_delta``.
Note: if this option is used, the step norm calculations may increase the time per iteration by ~25%.

References:

[Karimireddy, 2021] Karimireddy, Sai Praneeth, Lie He, and Martin Jaggi. "Learning from history for byzantine
robust optimization." International Conference on Machine Learning. PMLR, 2021.

[Gorbunov, 2021] Gorbunov, Eduard, Alexander Borzunov, Michael Diskin, and Max Ryabinin.
"Secure Distributed Training at Scale." arXiv preprint arXiv:2106.11257 (2021).
"""

    with torch.no_grad():
        n_peers = input_tensors.shape[0]
        result_shape = input_tensors.shape[1:]

        input_tensors = input_tensors.flatten(start_dim=1)
        weights /= weights.sum()

        # This finds medians faster than torch.median() and torch.quantile(q=0.5),
        # see https://github.com/pytorch/pytorch/issues/51450
        sorted_tensors = input_tensors.sort(dim=0).values
        result = sorted_tensors[n_peers // 2].clone()
        delta = None

        diff = torch.sub(input_tensors, result, out=sorted_tensors)  # Reuse memory from `sorted_tensors`
        for _ in range(n_iters):
            norms = diff.norm(dim=1)
            coeffs = weights * torch.minimum(torch.tensor(1.0), tau / norms)

            if stop_delta is not None:
                result[...] = diff[0]  # Reuse memory from `result`
                prev_diff = result

            # We only need to update `diff` (not `result`) between iterations
            diff.addmm_(-coeffs.repeat(n_peers, 1), diff)
> Member Author — suggested change:
>
>     - diff.addmm_(-coeffs.repeat(n_peers, 1), diff)
>     + diff -= coeffs @ diff
>
> It seems like addmm_() doesn't work correctly if the destination is equal to one of the operands.
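A minimal self-check of this aliasing concern (hypothetical shapes, independent of the PR) might look like:

```python
import torch

torch.manual_seed(0)
n_peers, dim = 4, 8
diff = torch.randn(n_peers, dim)
coeffs = torch.rand(n_peers)

expected = diff - coeffs @ diff  # out-of-place reference update
aliased = diff.clone()
aliased.addmm_(-coeffs.repeat(n_peers, 1), aliased)  # destination aliases the second operand

print(torch.allclose(expected, aliased))  # may print False if addmm_ reads partially-updated values
```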


            if stop_delta is not None:
                delta = prev_diff.sub_(diff[0]).abs().max()
                if delta < stop_delta:
                    break
        torch.sub(input_tensors[0], diff[0], out=result)

        return CenteredClipResult(
            result=result.reshape(result_shape), n_clipped=(tau < norms).sum(), last_step_delta=delta
        )
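For illustration, a toy invocation of `centered_clip` (hypothetical values; this assumes the in-place `addmm_` concern flagged above is resolved) could look like:

```python
import torch

# 5 peers each contribute a 3x2 tensor part; peer 0 simulates a Byzantine outlier
inputs = torch.randn(5, 3, 2)
inputs[0] += 100.0
weights = torch.ones(5)

clipped = centered_clip(inputs, weights, tau=1.0, n_iters=20)
print(clipped.result.shape)  # torch.Size([3, 2]) -- the robust aggregate
print(clipped.n_clipped)     # number of peers clipped on the last iteration
```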
5 changes: 4 additions & 1 deletion hivemind/averaging/allreduce.py
@@ -4,6 +4,7 @@

import torch

+from hivemind.averaging.accumulators import AccumulatorFactory
from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
@@ -58,6 +59,7 @@ def __init__(
tensors: Sequence[torch.Tensor],
ordered_peer_ids: Sequence[PeerID],
peer_fractions: Tuple[float, ...],
+accumulator_factory: AccumulatorFactory,
weights: Optional[Sequence[float]] = None,
modes: Optional[Sequence[AveragingMode]] = None,
gathered: Optional[Dict[PeerID, Any]] = None,
@@ -97,7 +99,8 @@ def __init__(
self.tensor_part_reducer = TensorPartReducer(
tuple(part.shape for part in self.parts_for_local_averaging),
len(self.sender_peer_ids),
-self.sender_weights,
+weights=self.sender_weights,
+accumulator_factory=accumulator_factory,
)

def __repr__(self):
3 changes: 3 additions & 0 deletions hivemind/averaging/averager.py
@@ -15,6 +15,7 @@
import numpy as np
import torch

+from hivemind.averaging.accumulators import AccumulatorFactory, MeanAccumulator
from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.load_balancing import load_balance_peers
@@ -112,6 +113,7 @@ def __init__(
compression: CompressionBase = NoCompression(),
state_compression: CompressionBase = NoCompression(),
tensor_infos: Optional[Sequence[CompressionInfo]] = None,
+accumulator_factory: AccumulatorFactory = MeanAccumulator,
bandwidth: Optional[float] = None,
min_vector_size: int = 0,
auxiliary: bool = False,
@@ -170,6 +172,7 @@ def __init__(
compression=compression,
part_size_bytes=part_size_bytes,
min_vector_size=min_vector_size,
+accumulator_factory=accumulator_factory,
)
self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
self._running_groups: Dict[GroupID, AllReduceRunner] = {} # one or more assembled groups that run all-reduce
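For context, a hypothetical end-to-end use of the new parameter might bind CenteredClip hyperparameters with `functools.partial`, since an `AccumulatorFactory` is called as `factory(part_shape, n_peers)`. The surrounding `DecentralizedAverager` arguments below are assumptions for illustration, not part of this diff:

```python
import functools

import torch
import hivemind
from hivemind.averaging.accumulators import CenteredClipAccumulator

dht = hivemind.DHT(start=True)
# Bind the CenteredClip hyperparameters; the averager will supply (part_shape, n_peers)
robust_factory = functools.partial(CenteredClipAccumulator, tau=1.0, n_iters=20)

averager = hivemind.averaging.DecentralizedAverager(
    averaged_tensors=[torch.zeros(16)],
    dht=dht,
    prefix="robust_demo",
    target_group_size=2,
    accumulator_factory=robust_factory,
    start=True,
)
```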
24 changes: 15 additions & 9 deletions hivemind/averaging/partition.py
@@ -8,6 +8,7 @@
import numpy as np
import torch

+from hivemind.averaging.accumulators import AccumulatorFactory
from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import amap_in_executor
@@ -171,16 +172,23 @@ class TensorPartReducer:
:note: even if local peer is not sending data, local parts will be used for shape information
"""

-def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
+def __init__(
+    self,
+    part_shapes: Sequence[torch.Size],
+    num_senders: int,
+    *,
+    weights: Optional[Sequence[float]],
+    accumulator_factory: AccumulatorFactory,
+):
self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
self.weights = tuple(weights or (1 for _ in range(num_senders)))
assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
assert all(isinstance(weight, (int, float)) for weight in self.weights)
self.current_part_index = -1 # index in local_parts of the part that should be loaded next
self.current_part_accumulated_from = 0 # number of peers from which the current part was accumulated
-self.accumulator = None  # this will contain the sum of current tensor part from group peers
-self.denominator = 0.0  # total weight accumulated from all peers for current part
self.current_part_future = asyncio.Future()
+self.accumulator_factory = accumulator_factory
+self.accumulator = None
self.finished = asyncio.Event()
self.reset_accumulators()

@@ -194,8 +202,7 @@ def reset_accumulators(self):
self.current_part_index += 1
self.current_part_accumulated_from = 0
self.current_part_future = asyncio.Future()
-self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
-self.denominator = 0.0
+self.accumulator = self.accumulator_factory(self.part_shapes[self.current_part_index], self.num_senders)

async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
"""Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
@@ -211,21 +218,20 @@ async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:

current_part_future = self.current_part_future

-self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
-self.denominator += self.weights[sender_index]
+self.accumulator.accumulate_part(tensor_part, self.weights[sender_index])
self.current_part_accumulated_from += 1

assert self.current_part_accumulated_from <= self.num_senders
if self.current_part_accumulated_from == self.num_senders:
-current_part_future.set_result(self.accumulator.div_(self.denominator))
+current_part_future.set_result(self.accumulator.reduce())
self.reset_accumulators()
return await current_part_future

def finalize(self):
if not self.finished.is_set():
if hasattr(self, "current_part_future"):
self.current_part_future.cancel()
-del self.accumulator
+self.accumulator = None
self.finished.set()

def __del__(self):
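As a sketch of the extension point this change exposes, a custom strategy only needs the two `AccumulatorBase` methods. A hypothetical (unweighted) coordinate-wise median accumulator:

```python
import torch

from hivemind.averaging.accumulators import AccumulatorBase


class MedianAccumulator(AccumulatorBase):
    """Hypothetical example; ignores peer weights."""

    def __init__(self, part_shape: torch.Size, n_peers: int):
        self._parts = torch.empty((n_peers, *part_shape))
        self._index = 0

    def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
        self._parts[self._index] = tensor_part  # weight is deliberately unused here
        self._index += 1

    def reduce(self) -> torch.Tensor:
        return self._parts[: self._index].median(dim=0).values
```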
4 changes: 3 additions & 1 deletion tests/test_allreduce.py
@@ -7,6 +7,7 @@
import torch

from hivemind import Quantile8BitQuantization, aenumerate
+from hivemind.averaging.accumulators import MeanAccumulator
from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
from hivemind.compression import deserialize_torch_tensor
@@ -119,7 +120,7 @@ async def wait_synchronously():
@pytest.mark.asyncio
async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float):
tensor_part_shapes = [torch.Size([i]) for i in range(num_parts)]
-reducer = TensorPartReducer(tensor_part_shapes, num_senders)
+reducer = TensorPartReducer(tensor_part_shapes, num_senders, weights=None, accumulator_factory=MeanAccumulator)

local_tensors_by_sender = [[torch.randn(i) for i in range(num_parts)] for j in range(num_senders)]

@@ -196,6 +197,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
tensors=[x.clone() for x in tensors_by_peer[p2p.peer_id]],
ordered_peer_ids=peers,
peer_fractions=peer_fractions,
+accumulator_factory=MeanAccumulator,
modes=peer_modes,
weights=averaging_weights,
part_size_bytes=part_size_bytes,
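A natural follow-up test (a sketch, not part of this PR; it assumes the `addmm_` aliasing fix suggested above) could check that `centered_clip` tolerates an outlier peer:

```python
import torch

from hivemind.averaging.accumulators import centered_clip


def test_centered_clip_resists_outlier():
    torch.manual_seed(0)
    honest = torch.randn(7, 16)
    outlier = torch.full((1, 16), 1000.0)
    inputs = torch.cat([honest, outlier])
    weights = torch.ones(8)

    robust = centered_clip(inputs, weights, tau=1.0).result
    naive = inputs.mean(dim=0)
    honest_mean = honest.mean(dim=0)

    # the clipped aggregate should stay much closer to the honest mean than a plain average
    assert (robust - honest_mean).norm() < (naive - honest_mean).norm()
```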