Skip to content

Commit

Permalink
Add ProgressTracker (#408)
Browse files Browse the repository at this point in the history
Auxiliary class that keeps track of local & global training progress, measured in epochs.
An epoch can be incremented after collaboration accumulates a said number of gradients (target_batch_size).
Similarly to pytorch LR scheduler, epoch can be incremented on a single optimizer update or many local updates.

Co-authored-by: Anton Sinitsin <[email protected]>
Co-authored-by: Alexander Borzunov <[email protected]>
  • Loading branch information
3 people authored Nov 16, 2021
1 parent 99a0c18 commit d883387
Show file tree
Hide file tree
Showing 2 changed files with 430 additions and 0 deletions.
321 changes: 321 additions & 0 deletions hivemind/optim/experimental/progress_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,321 @@
import asyncio
import contextlib
import logging
import threading
from dataclasses import dataclass
from typing import Dict, Optional

import numpy as np
from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint

from hivemind.dht import DHT
from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator
from hivemind.utils import DHTExpiration, ValueWithExpiration, enter_asynchronously, get_dht_time, get_logger
from hivemind.utils.crypto import RSAPrivateKey
from hivemind.utils.performance_ema import PerformanceEMA

logger = get_logger(__name__)


@dataclass(frozen=False)
class GlobalTrainingProgress:
epoch: int
samples_accumulated: int
target_batch_size: int
num_peers: int
num_clients: int
eta_next_epoch: float
next_fetch_time: float


class LocalTrainingProgress(BaseModel):
peer_id: bytes
epoch: conint(ge=0, strict=True)
samples_accumulated: conint(ge=0, strict=True)
samples_per_second: confloat(ge=0.0, strict=True)
time: StrictFloat
client_mode: StrictBool


class TrainingProgressSchema(BaseModel):
progress: Dict[BytesWithPublicKey, Optional[LocalTrainingProgress]]


class ProgressTracker(threading.Thread):
"""
Auxiliary class that keeps track of local & global training progress, measured in epochs.
An epoch can be incremented after collaboration accumulates a said number of gradients (target_batch_size).
Similarly to pytorch LR scheduler, epoch can be incremented on a single optimizer update or many local updates.
:param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
:param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
:param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
:param expected_drift_peers: assume that this many new peers can join between epochs
:param expected_drift_rate: assumes that this fraction of current collaboration can join/leave between epochs
:note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
refresh the collaboration-wide statistics (to avoid missing the moment when peers transition to the next epoch)
:param performance_ema_alpha: smoothing value used to estimate this peer's performance (samples per second)
:param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
Example:
>>> tracker = ProgressTracker(hivemind.DHT(...), prefix="my_experiment_with_several_peers", target_batch_size=100)
>>> local_epoch, local_samples = 0, 0
>>> while True:
>>> accumulate_gradients(batch_size=32)
>>> local_samples += 32
>>> tracker.report_local_progress(local_epoch, local_samples)
>>> if local_epoch < tracker.global_progress.epoch:
>>> download_state_from_peers() # if peer is out of sync, synchronize it with the swarm
>>> if tracker.accumulated_enough_samples:
>>> with tracker.pause_updates():
>>> aggregate_gradients_with_peers()
>>> update_model_parameters()
>>> local_epoch = tracker.update_epoch(local_epoch + 1)
>>> local_samples = 0
"""

def __init__(
self,
dht: DHT,
prefix: str,
target_batch_size: int,
*,
client_mode: Optional[bool] = None,
min_refresh_period: float = 0.5,
max_refresh_period: float = 30,
default_refresh_period: float = 3,
expected_drift_peers: float = 3,
expected_drift_rate: float = 0.2,
performance_ema_alpha: float = 0.1,
metadata_expiration: float = 30.0,
status_loglevel: int = logging.DEBUG,
private_key: Optional[RSAPrivateKey] = None,
daemon: bool = True,
start: bool,
):
client_mode = client_mode if client_mode is not None else dht.client_mode
self.dht, self.prefix, self.client_mode = dht, prefix, client_mode
self.training_progress_key = f"{self.prefix}_progress"
self.target_batch_size = target_batch_size
self.min_refresh_period, self.max_refresh_period = min_refresh_period, max_refresh_period
self.default_refresh_period = default_refresh_period
self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
self.status_loglevel = status_loglevel
self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
self.metadata_expiration = metadata_expiration

signature_validator = RSASignatureValidator(private_key)
self._local_public_key = signature_validator.local_public_key
dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])

# report the collaboration progress periodically or in background
self.local_progress = self._get_local_progress(local_epoch=0, samples_accumulated=0)
metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
self.global_progress = self._parse_swarm_progress_data(metadata)
self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
self.should_report_progress = threading.Event()
self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
if start:
self.start()

@property
def global_epoch(self) -> int:
return self.global_progress.epoch

@property
def ready_to_update_epoch(self) -> bool:
"""Whether or not this peer can increment epoch right away."""
return (
self.global_epoch > self.local_progress.epoch
or self.global_progress.samples_accumulated >= self.target_batch_size
or get_dht_time() >= self.global_progress.eta_next_epoch
)

@property
def estimated_next_update_time(self) -> DHTExpiration:
"""Estimate (absolute) time when this peer should increment epoch"""
if self.ready_to_update_epoch:
return get_dht_time()
return self.global_progress.eta_next_epoch

def _get_local_progress(self, local_epoch: int, samples_accumulated: int):
return LocalTrainingProgress(
peer_id=self.dht.peer_id.to_bytes(),
epoch=local_epoch,
samples_accumulated=samples_accumulated,
samples_per_second=self.performance_ema.samples_per_second,
time=get_dht_time(),
client_mode=self.client_mode,
)

def report_local_progress(self, local_epoch: int, samples_accumulated: int):
"""Update the number of locally accumulated samples and notify to other peers about this."""
extra_samples = samples_accumulated - self.local_progress.samples_accumulated
if extra_samples > 0:
self.performance_ema.update(task_size=extra_samples)
logger.debug(f"Updated performance EMA: {self.performance_ema.samples_per_second:.5f}")
else:
logger.debug("Resetting performance timestamp to current time (progress was reset)")
self.performance_ema.reset_timer()
self.local_progress = self._get_local_progress(local_epoch, samples_accumulated)
self.should_report_progress.set()

@contextlib.contextmanager
def pause_updates(self):
"""Temporarily stop progress tracker from updating global training state"""
with self.lock_global_progress, self.performance_ema.pause():
yield

def update_epoch(self, new_epoch: Optional[int] = None) -> int:
"""Update the local epoch, reset the number of sample accumulated, reset local progress, return new epoch"""
assert self.lock_global_progress.locked(), "ProgressTracker must be paused when incrementing epoch"
if new_epoch is None:
new_epoch = self.local_progress.epoch + 1
if new_epoch > self.global_progress.epoch:
self.global_progress.epoch = new_epoch
self.global_progress.samples_accumulated = 0
self.global_progress.eta_next_epoch = float("inf")
self.report_local_progress(new_epoch, samples_accumulated=0)
return new_epoch

def run(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(asyncio.gather(self._progress_reporter(), self._progress_fetcher()))
self.shutdown_complete.set()

async def _progress_reporter(self):
"""Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
last_report_time = -float("inf")
try:
while not self.shutdown_triggered.is_set():
wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
await asyncio.get_event_loop().run_in_executor(None, self.should_report_progress.wait, wait_timeout)
if self.should_report_progress.is_set():
logger.debug(f"Progress update triggered by report_local_progress.")
self.should_report_progress.clear()
else:
logger.debug(f"Progress update triggered by metadata_expiration.")

local_progress = self.local_progress
last_report_time = get_dht_time()

await self.dht.store(
key=self.training_progress_key,
subkey=self._local_public_key,
value=local_progress.dict(),
expiration_time=last_report_time + self.metadata_expiration,
return_future=True,
)
finally:
logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}.")

async def _progress_fetcher(self):
"""
Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
"""
loop = asyncio.get_event_loop()
try:
while not self.shutdown_triggered.is_set():
time_to_next_update = max(0.0, self.global_progress.next_fetch_time - get_dht_time())
state_updated_externally = await loop.run_in_executor(
None, self.global_state_updated.wait, time_to_next_update
)
if state_updated_externally:
self.global_state_updated.clear()
continue

async with enter_asynchronously(self.lock_global_progress):
progress_entry = await self.dht.get(self.training_progress_key, latest=True, return_future=True)
metadata = progress_entry.value if isinstance(progress_entry, ValueWithExpiration) else None
self.global_progress = self._parse_swarm_progress_data(metadata)
finally:
logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}.")

def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> GlobalTrainingProgress:
"""Read performance statistics reported by peers, estimate progress towards next batch"""
current_time = get_dht_time()

if not isinstance(metadata, dict) or len(metadata) == 0:
logger.log(self.status_loglevel, f"Found no active peers: {metadata}")
samples_remaining_to_next_epoch = max(0, self.target_batch_size - self.local_progress.samples_accumulated)
local_eta_next_epoch = samples_remaining_to_next_epoch / self.performance_ema.samples_per_second

return GlobalTrainingProgress(
self.local_progress.epoch,
self.local_progress.samples_accumulated,
self.target_batch_size,
num_peers=0,
num_clients=0,
eta_next_epoch=current_time + local_eta_next_epoch,
next_fetch_time=current_time + self.default_refresh_period,
)

valid_peer_entries = [
LocalTrainingProgress.parse_obj(peer_state.value)
for peer_state in metadata.values()
if peer_state.value is not None
]

num_peers = len(valid_peer_entries)
num_clients = sum(peer.client_mode for peer in valid_peer_entries)

global_epoch = self.local_progress.epoch
for peer in valid_peer_entries:
if not peer.client_mode:
global_epoch = max(global_epoch, peer.epoch)

total_samples_accumulated = estimated_current_samples = 0
total_samples_per_second = self.performance_ema.eps

for peer in valid_peer_entries:
total_samples_per_second += peer.samples_per_second
if peer.epoch == global_epoch:
total_samples_accumulated += peer.samples_accumulated
estimated_current_samples += (
peer.samples_accumulated + max(0.0, current_time - peer.time) * peer.samples_per_second
)
# note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
# the rationale behind this is that outdated peers will synchronize and begin contributing shortly.

estimated_samples_remaining = self.target_batch_size - estimated_current_samples
estimated_time_to_next_epoch = max(0, estimated_samples_remaining) / total_samples_per_second

expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
time_to_next_fetch = float(
np.clip(
a=estimated_time_to_next_epoch * num_peers / expected_max_peers,
a_min=self.min_refresh_period,
a_max=self.max_refresh_period,
)
)
logger.log(
self.status_loglevel,
f"{self.prefix} accumulated {total_samples_accumulated} samples for iteration #{global_epoch} from "
f"{num_peers} peers. ETA {estimated_time_to_next_epoch:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
)
return GlobalTrainingProgress(
global_epoch,
total_samples_accumulated,
target_batch_size=self.target_batch_size,
num_peers=num_peers,
num_clients=num_clients,
eta_next_epoch=current_time + estimated_time_to_next_epoch,
next_fetch_time=current_time + time_to_next_fetch,
)

def shutdown(self):
"""Permanently disable all tracking activity"""
self.shutdown_triggered.set()
self.should_report_progress.set()
self.global_state_updated.set()
self.shutdown_complete.wait()
self.dht.store(
self.training_progress_key,
subkey=self._local_public_key,
value=None,
expiration_time=get_dht_time() + self.metadata_expiration,
)
Loading

0 comments on commit d883387

Please sign in to comment.