[hivemind.Optimizer] ProgressTracker #408
Conversation
Codecov Report
@@            Coverage Diff             @@
##            master     #408      +/-  ##
==========================================
+ Coverage    84.17%   84.57%   +0.40%
==========================================
  Files           75       76       +1
  Lines         7121     7280     +159
==========================================
+ Hits          5994     6157     +163
+ Misses        1127     1123       -4
@contextlib.contextmanager
def pause_updates(self):
    """Temporarily stop progress tracker from updating global training state"""
    with self.lock_global_progress, self.performance_ema.pause():
(may be out of scope of this PR)
I'm surprised that performance_ema.pause() resets the timer instead of actually pausing it.
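For illustration, "actually pausing" would mean excluding the paused interval from the next rate estimate rather than resetting the reference timestamp. A minimal sketch of that idea with a hypothetical EMA class (this is not the real PerformanceEMA API):

import contextlib
import time


class TimedEMA:
    """Hypothetical samples-per-second EMA, for illustration only."""

    def __init__(self, alpha: float = 0.1):
        self.alpha, self.samples_per_second, self.timestamp = alpha, 0.0, time.monotonic()

    def update(self, num_samples: int) -> float:
        now = time.monotonic()
        instantaneous = num_samples / max(now - self.timestamp, 1e-9)
        self.samples_per_second = self.alpha * instantaneous + (1 - self.alpha) * self.samples_per_second
        self.timestamp = now
        return self.samples_per_second

    @contextlib.contextmanager
    def pause(self):
        """Exclude time spent inside this block from the next update, instead of resetting."""
        paused_at = time.monotonic()
        try:
            yield
        finally:
            # shift the reference point forward by however long we were paused
            self.timestamp += time.monotonic() - paused_at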
try:
    while not self.shutdown_triggered.is_set():
        wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
        logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command.")
nit: Here and further:
Suggested change:
logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
if self.global_epoch > self.local_progress.epoch:
    return True
elif self.global_progress.samples_accumulated >= self.target_batch_size:
    return True
elif get_dht_time() >= self.global_progress.eta_next_epoch:
    return True
return False
Suggested change:
return (
    self.global_epoch > self.local_progress.epoch or
    self.global_progress.samples_accumulated >= self.target_batch_size or
    get_dht_time() >= self.global_progress.eta_next_epoch
)
tests/test_optimizer.py (outdated)
assert not tracker.is_alive()

mean_step_time = sum(step_time_deltas) / len(step_time_deltas)
for i in (0, 1, 5):
Suggested change:
for i in (0, 1, 5):  # Without the 4th worker (the fastest one)
tests/test_optimizer.py (outdated)
mean_step_time = sum(step_time_deltas) / len(step_time_deltas)
for i in (0, 1, 5):
    assert 1.05 * mean_step_time < step_time_deltas[i] < 2.0 * mean_step_time
for i in (2, 3, 4):
Suggested change:
for i in (2, 3, 4):  # With the 4th worker
performance_ema_alpha: float = 0.1,
metadata_expiration: float = 30.0,
status_loglevel: int = logging.DEBUG,
private_key: PrivateKey = None,
Suggested change:
private_key: Optional[RSAPrivateKey] = None,
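For reference, a minimal sketch of what the annotated signature would need at the top of the file; the import path for RSAPrivateKey is an assumption (mirroring what the signature validator uses) and the constructor shown is not the full one:

import logging
from typing import Optional

from hivemind.utils.crypto import RSAPrivateKey  # assumed import path


class ProgressTracker:
    # illustration only: other constructor arguments omitted
    def __init__(
        self,
        performance_ema_alpha: float = 0.1,
        metadata_expiration: float = 30.0,
        status_loglevel: int = logging.DEBUG,
        private_key: Optional[RSAPrivateKey] = None,
    ):
        ...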
tests/test_optimizer.py (outdated)
tracker.shutdown()
dht.shutdown()

# note: we use processes here because RSASignatureValidator inside trackers uses process-wide RSA keys
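To illustrate why processes matter here: threads within a single process would all share the same process-wide RSA key, so the progress entries of the simulated peers could not be told apart by signature. A rough sketch of the process-per-peer pattern (the body of run_tracker is hypothetical, not the actual test):

import multiprocessing as mp


def run_tracker(worker_index: int) -> None:
    # Each child process builds its own RSASignatureValidator, hence its own
    # process-wide RSA key, so every simulated peer signs progress with a distinct key.
    ...  # hypothetical body: start a DHT node and a ProgressTracker, report progress, shut down


if __name__ == "__main__":
    workers = [mp.Process(target=run_tracker, args=(i,)) for i in range(4)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()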
Auxiliary class that keeps track of local & global training progress, measured in epochs.
An epoch can be incremented after the collaboration accumulates a given number of gradients (target_batch_size).
Similarly to a PyTorch LR scheduler, an epoch may correspond to a single optimizer update or to many local updates.
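To make the epoch semantics concrete, here is a rough usage sketch. The constructor arguments and the report_local_progress / ready_to_update_epoch / update_epoch names are assumptions drawn from this PR's diff and tests, and may not match the final API exactly:

import hivemind
from hivemind.optim.progress_tracker import ProgressTracker  # assumed module path

dht = hivemind.DHT(start=True)
tracker = ProgressTracker(dht=dht, prefix="my_run", target_batch_size=4096, start=True)

local_samples = 0
while not tracker.ready_to_update_epoch:
    ...  # accumulate gradients over a local batch of, say, 32 samples
    local_samples += 32
    tracker.report_local_progress(tracker.local_progress.epoch, local_samples)

# target_batch_size was reached (or the epoch ETA expired): perform one global optimizer
# step, analogous to a single scheduler epoch after many local gradient accumulations
with tracker.pause_updates():
    ...  # apply the averaged gradients / optimizer.step() here
tracker.update_epoch(tracker.local_progress.epoch + 1)
tracker.shutdown()
dht.shutdown()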