Auxiliary class that keeps track of local & global training progress, measured in epochs. An epoch can be incremented after the collaboration accumulates a set number of gradients (target_batch_size). Similarly to the PyTorch LR scheduler, an epoch can be incremented after a single optimizer update or after many local updates.

Co-authored-by: Anton Sinitsin <[email protected]>
Co-authored-by: Alexander Borzunov <[email protected]>
1 parent 99a0c18 · commit d883387
Showing 2 changed files with 430 additions and 0 deletions.
@@ -0,0 +1,321 @@
import asyncio
import contextlib
import logging
import threading
from dataclasses import dataclass
from typing import Dict, Optional

import numpy as np
from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint

from hivemind.dht import DHT
from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator
from hivemind.utils import DHTExpiration, ValueWithExpiration, enter_asynchronously, get_dht_time, get_logger
from hivemind.utils.crypto import RSAPrivateKey
from hivemind.utils.performance_ema import PerformanceEMA

logger = get_logger(__name__)


@dataclass(frozen=False)
class GlobalTrainingProgress:
    epoch: int
    samples_accumulated: int
    target_batch_size: int
    num_peers: int
    num_clients: int
    eta_next_epoch: float
    next_fetch_time: float


class LocalTrainingProgress(BaseModel):
    peer_id: bytes
    epoch: conint(ge=0, strict=True)
    samples_accumulated: conint(ge=0, strict=True)
    samples_per_second: confloat(ge=0.0, strict=True)
    time: StrictFloat
    client_mode: StrictBool


class TrainingProgressSchema(BaseModel):
    progress: Dict[BytesWithPublicKey, Optional[LocalTrainingProgress]]


class ProgressTracker(threading.Thread):
""" | ||
Auxiliary class that keeps track of local & global training progress, measured in epochs. | ||
An epoch can be incremented after collaboration accumulates a said number of gradients (target_batch_size). | ||
Similarly to pytorch LR scheduler, epoch can be incremented on a single optimizer update or many local updates. | ||
:param min_refresh_period: wait for at least this many seconds before fetching new collaboration state | ||
:param max_refresh_period: wait for at most this many seconds before fetching new collaboration state | ||
:param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds) | ||
:param expected_drift_peers: assume that this many new peers can join between epochs | ||
:param expected_drift_rate: assumes that this fraction of current collaboration can join/leave between epochs | ||
:note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer will | ||
refresh the collaboration-wide statistics (to avoid missing the moment when peers transition to the next epoch) | ||
:param performance_ema_alpha: smoothing value used to estimate this peer's performance (samples per second) | ||
:param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds | ||
Example: | ||
>>> tracker = ProgressTracker(hivemind.DHT(...), prefix="my_experiment_with_several_peers", target_batch_size=100) | ||
>>> local_epoch, local_samples = 0, 0 | ||
>>> while True: | ||
>>> accumulate_gradients(batch_size=32) | ||
>>> local_samples += 32 | ||
>>> tracker.report_local_progress(local_epoch, local_samples) | ||
>>> if local_epoch < tracker.global_progress.epoch: | ||
>>> download_state_from_peers() # if peer is out of sync, synchronize it with the swarm | ||
>>> if tracker.accumulated_enough_samples: | ||
>>> with tracker.pause_updates(): | ||
>>> aggregate_gradients_with_peers() | ||
>>> update_model_parameters() | ||
>>> local_epoch = tracker.update_epoch(local_epoch + 1) | ||
>>> local_samples = 0 | ||
""" | ||

    def __init__(
        self,
        dht: DHT,
        prefix: str,
        target_batch_size: int,
        *,
        client_mode: Optional[bool] = None,
        min_refresh_period: float = 0.5,
        max_refresh_period: float = 30,
        default_refresh_period: float = 3,
        expected_drift_peers: float = 3,
        expected_drift_rate: float = 0.2,
        performance_ema_alpha: float = 0.1,
        metadata_expiration: float = 30.0,
        status_loglevel: int = logging.DEBUG,
        private_key: Optional[RSAPrivateKey] = None,
        daemon: bool = True,
        start: bool,
    ):
        client_mode = client_mode if client_mode is not None else dht.client_mode
        self.dht, self.prefix, self.client_mode = dht, prefix, client_mode
        self.training_progress_key = f"{self.prefix}_progress"
        self.target_batch_size = target_batch_size
        self.min_refresh_period, self.max_refresh_period = min_refresh_period, max_refresh_period
        self.default_refresh_period = default_refresh_period
        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
        self.status_loglevel = status_loglevel
        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
        self.metadata_expiration = metadata_expiration

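        # validators: enforce the progress schema and sign this peer's DHT entries so others can verify their origin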
        signature_validator = RSASignatureValidator(private_key)
        self._local_public_key = signature_validator.local_public_key
        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])

        # report the collaboration progress periodically or in background
        self.local_progress = self._get_local_progress(local_epoch=0, samples_accumulated=0)
        metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
        self.global_progress = self._parse_swarm_progress_data(metadata)
        self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
        self.should_report_progress = threading.Event()
        self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
        super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
        if start:
            self.start()

    @property
    def global_epoch(self) -> int:
        return self.global_progress.epoch

    @property
    def ready_to_update_epoch(self) -> bool:
        """Whether or not this peer can increment epoch right away."""
        return (
            self.global_epoch > self.local_progress.epoch
            or self.global_progress.samples_accumulated >= self.target_batch_size
            or get_dht_time() >= self.global_progress.eta_next_epoch
        )

    @property
    def estimated_next_update_time(self) -> DHTExpiration:
        """Estimate (absolute) time when this peer should increment epoch"""
        if self.ready_to_update_epoch:
            return get_dht_time()
        return self.global_progress.eta_next_epoch

    def _get_local_progress(self, local_epoch: int, samples_accumulated: int):
        return LocalTrainingProgress(
            peer_id=self.dht.peer_id.to_bytes(),
            epoch=local_epoch,
            samples_accumulated=samples_accumulated,
            samples_per_second=self.performance_ema.samples_per_second,
            time=get_dht_time(),
            client_mode=self.client_mode,
        )

    def report_local_progress(self, local_epoch: int, samples_accumulated: int):
        """Update the number of locally accumulated samples and notify other peers about this."""
        extra_samples = samples_accumulated - self.local_progress.samples_accumulated
        if extra_samples > 0:
            self.performance_ema.update(task_size=extra_samples)
            logger.debug(f"Updated performance EMA: {self.performance_ema.samples_per_second:.5f}")
        else:
            logger.debug("Resetting performance timestamp to current time (progress was reset)")
            self.performance_ema.reset_timer()
        self.local_progress = self._get_local_progress(local_epoch, samples_accumulated)
        self.should_report_progress.set()

    @contextlib.contextmanager
    def pause_updates(self):
        """Temporarily stop the progress tracker from updating the global training state"""
        with self.lock_global_progress, self.performance_ema.pause():
            yield

    def update_epoch(self, new_epoch: Optional[int] = None) -> int:
        """Update the local epoch, reset the number of samples accumulated, reset local progress, return new epoch"""
        assert self.lock_global_progress.locked(), "ProgressTracker must be paused when incrementing epoch"
        if new_epoch is None:
            new_epoch = self.local_progress.epoch + 1
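        # advance the cached global progress immediately so that this peer does not re-trigger an epoch update
        # before the next scheduled fetch from the DHT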
        if new_epoch > self.global_progress.epoch:
            self.global_progress.epoch = new_epoch
            self.global_progress.samples_accumulated = 0
            self.global_progress.eta_next_epoch = float("inf")
        self.report_local_progress(new_epoch, samples_accumulated=0)
        return new_epoch

    def run(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
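        # run the progress reporter and the progress fetcher concurrently until shutdown is triggered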
        loop.run_until_complete(asyncio.gather(self._progress_reporter(), self._progress_fetcher()))
        self.shutdown_complete.set()

    async def _progress_reporter(self):
        """Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
        last_report_time = -float("inf")
        try:
            while not self.shutdown_triggered.is_set():
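                # wake up either when report_local_progress requests an update or when the previously stored
                # DHT entry is about to expire, so this peer's progress record stays visible in the DHT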
                wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
                logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
                await asyncio.get_event_loop().run_in_executor(None, self.should_report_progress.wait, wait_timeout)
                if self.should_report_progress.is_set():
                    logger.debug("Progress update triggered by report_local_progress.")
                    self.should_report_progress.clear()
                else:
                    logger.debug("Progress update triggered by metadata_expiration.")

                local_progress = self.local_progress
                last_report_time = get_dht_time()

                await self.dht.store(
                    key=self.training_progress_key,
                    subkey=self._local_public_key,
                    value=local_progress.dict(),
                    expiration_time=last_report_time + self.metadata_expiration,
                    return_future=True,
                )
        finally:
            logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}.")

    async def _progress_fetcher(self):
        """
        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
        """
        loop = asyncio.get_event_loop()
        try:
            while not self.shutdown_triggered.is_set():
                time_to_next_update = max(0.0, self.global_progress.next_fetch_time - get_dht_time())
                state_updated_externally = await loop.run_in_executor(
                    None, self.global_state_updated.wait, time_to_next_update
                )
                if state_updated_externally:
                    self.global_state_updated.clear()
                    continue

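                # lock_global_progress is also held by pause_updates(); waiting for it here ensures that the
                # global progress is not overwritten while a peer is averaging gradients or updating its epoch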
                async with enter_asynchronously(self.lock_global_progress):
                    progress_entry = await self.dht.get(self.training_progress_key, latest=True, return_future=True)
                    metadata = progress_entry.value if isinstance(progress_entry, ValueWithExpiration) else None
                    self.global_progress = self._parse_swarm_progress_data(metadata)
        finally:
            logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}.")

    def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> GlobalTrainingProgress:
        """Read performance statistics reported by peers, estimate progress towards next batch"""
        current_time = get_dht_time()

        if not isinstance(metadata, dict) or len(metadata) == 0:
            logger.log(self.status_loglevel, f"Found no active peers: {metadata}")
            samples_remaining_to_next_epoch = max(0, self.target_batch_size - self.local_progress.samples_accumulated)
            local_eta_next_epoch = samples_remaining_to_next_epoch / self.performance_ema.samples_per_second

            return GlobalTrainingProgress(
                self.local_progress.epoch,
                self.local_progress.samples_accumulated,
                self.target_batch_size,
                num_peers=0,
                num_clients=0,
                eta_next_epoch=current_time + local_eta_next_epoch,
                next_fetch_time=current_time + self.default_refresh_period,
            )

        valid_peer_entries = [
            LocalTrainingProgress.parse_obj(peer_state.value)
            for peer_state in metadata.values()
            if peer_state.value is not None
        ]

        num_peers = len(valid_peer_entries)
        num_clients = sum(peer.client_mode for peer in valid_peer_entries)

        global_epoch = self.local_progress.epoch
        for peer in valid_peer_entries:
            if not peer.client_mode:
                global_epoch = max(global_epoch, peer.epoch)

        total_samples_accumulated = estimated_current_samples = 0
        total_samples_per_second = self.performance_ema.eps

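        # sum reported samples for peers on the current epoch and extrapolate each peer's progress by the time
        # elapsed since its last report, assuming it kept processing at the reported samples-per-second rate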
        for peer in valid_peer_entries:
            total_samples_per_second += peer.samples_per_second
            if peer.epoch == global_epoch:
                total_samples_accumulated += peer.samples_accumulated
                estimated_current_samples += (
                    peer.samples_accumulated + max(0.0, current_time - peer.time) * peer.samples_per_second
                )
            # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
            # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.

        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
        estimated_time_to_next_epoch = max(0, estimated_samples_remaining) / total_samples_per_second

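        # fetch sooner than the estimated ETA to account for peers that may join or leave in the meantime,
        # clipped between min_refresh_period and max_refresh_period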
        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
        time_to_next_fetch = float(
            np.clip(
                a=estimated_time_to_next_epoch * num_peers / expected_max_peers,
                a_min=self.min_refresh_period,
                a_max=self.max_refresh_period,
            )
        )
        logger.log(
            self.status_loglevel,
            f"{self.prefix} accumulated {total_samples_accumulated} samples for iteration #{global_epoch} from "
            f"{num_peers} peers. ETA {estimated_time_to_next_epoch:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
        )
        return GlobalTrainingProgress(
            global_epoch,
            total_samples_accumulated,
            target_batch_size=self.target_batch_size,
            num_peers=num_peers,
            num_clients=num_clients,
            eta_next_epoch=current_time + estimated_time_to_next_epoch,
            next_fetch_time=current_time + time_to_next_fetch,
        )

    def shutdown(self):
        """Permanently disable all tracking activity"""
        self.shutdown_triggered.set()
        self.should_report_progress.set()
        self.global_state_updated.set()
        self.shutdown_complete.wait()
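        # overwrite this peer's progress entry with None so that other peers stop counting it towards the epoch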
        self.dht.store(
            self.training_progress_key,
            subkey=self._local_public_key,
            value=None,
            expiration_time=get_dht_time() + self.metadata_expiration,
        )