Simplify ExpertBackend interface #483
Conversation
- extract batching and clipping from ExpertBackend, reassign this role to the optimizer/scheduler
- rename full_state -> state_dict; rationale: there is no "non-full" state in this context
- rename ExpertBackend.expert -> ExpertBackend.module to avoid confusion
@@ -54,7 +54,8 @@ def main():
                         help='Server will report experts to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')
-    parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
+    parser.add_argument('--num_training_steps', type=int, required=False, help='The total number of steps for LR schedule')
Name changed in the source: https://github.com/huggingface/transformers/blob/v4.19.4/src/transformers/optimization.py#L75
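For context, `num_training_steps` is the parameter name used by the linked `transformers` schedule (`get_linear_schedule_with_warmup`). A minimal pure-Python sketch of that schedule's LR multiplier, mirroring the formula in the linked source (the function name `lr_lambda` is illustrative):

```python
def lr_lambda(current_step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """LR multiplier: linear warmup from 0 to 1, then linear decay back to 0."""
    if current_step < num_warmup_steps:
        return current_step / max(1, num_warmup_steps)
    return max(
        0.0,
        (num_training_steps - current_step) / max(1, num_training_steps - num_warmup_steps),
    )
```

In real code this function would be passed (with warmup/total steps bound) to `torch.optim.lr_scheduler.LambdaLR`.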
@@ -163,65 +141,47 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         torch.autograd.backward(
             outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False
         )
-        self.apply_gradients(batch_size)
+        self.on_backward(batch_size)
rationale: this does not necessarily apply gradients, e.g.
- virtual batching applies gradients once every k steps
- pretrained models do not apply the returned gradients
Codecov Report

@@            Coverage Diff             @@
##           master     #483      +/-   ##
==========================================
- Coverage   85.97%   85.78%   -0.20%
==========================================
  Files          79       80       +1
  Lines        7772     7808      +36
==========================================
+ Hits         6682     6698      +16
- Misses       1090     1110      +20
hivemind/moe/server/server.py
Outdated
optimizer = optim_cls(expert.parameters()) if optim_cls is not None else None
scheduler = scheduler_cls(optimizer) if scheduler_cls is not None else None
if clip_grad_norm is not None:
    scheduler = ClippingWrapper(scheduler, clip_grad_norm)
TODO FIX
import torch
from torch import nn

from hivemind.moe.server.task_pool import TaskPool
from hivemind.optim.state_averager import LRSchedulerBase
Since this import is not actually related to hivemind.optim, I'd suggest to simply inline the statement that is being imported.
fixed
It should return gradients w.r.t. inputs that follow ``nested_flatten(self.outputs_schema)``;

.. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
Is this not correct anymore? :)
This is arguably no different than gradient checkpointing, and it is unlikely that we can fix it here for all cases without the user defining custom layers. I can keep the todo if you insist [please reply here if so]. Alternatively, perhaps it would be best to change this todo into a warning/note. Your call?
A warning would be sufficient, I suppose
restored the warning
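The double-update issue discussed above can be demonstrated abstractly. In this sketch (all names are mine, and a call counter stands in for BatchNorm's running mean/var), re-running the forward pass during backward, as activation checkpointing does, updates the running statistics a second time:

```python
class StatefulLayer:
    """Stand-in for a layer with running statistics (e.g. BatchNorm in training mode)."""
    def __init__(self):
        self.stat_updates = 0

    def forward(self, x):
        self.stat_updates += 1  # running stats change on every training-mode forward
        return x


def backward_with_recompute(layer, x):
    """Like checkpointing: the forward is re-run during backward, updating stats again."""
    layer.forward(x)
```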
hivemind/moe/server/layers/optim.py
Outdated
"""A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer"""

def __init__(self, optim: torch.optim.Optimizer):
    object.__init__(self)
object? In that case, it’s not a true Optimizer
On the contrary: it defines all optimizer fields and methods as properties of the wrapper.
return self.optim.add_param_group(param_group)


class ClippingWrapper(OptimizerWrapper):
In case we need OptimizerWrapper just for this application, I’d suggest not to overcomplicate the code and just write one class for a specific use case
Co-authored-by: Max Ryabinin <[email protected]>
The core idea is that we should not make hivemind internals conditional on a specific training technique, such as linear warmup or gradient clipping by norm. Instead, we let the user define their own scheduler and/or optimizer as necessary.
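The resulting contract can be sketched as follows, mirroring the `optim_cls(...)` / `scheduler_cls(...)` call sites in the diff above (the factory function itself is my illustration, not hivemind API; in real code one would pass e.g. `optim_cls=partial(torch.optim.SGD, lr=0.01)`):

```python
from functools import partial  # used by callers to bind hyperparameters (see note above)


def make_backend_components(parameters, optim_cls=None, scheduler_cls=None):
    """User-chosen optimizer/scheduler: any warmup or clipping lives inside these classes."""
    optimizer = optim_cls(parameters) if optim_cls is not None else None
    scheduler = scheduler_cls(optimizer) if scheduler_cls is not None else None
    return optimizer, scheduler
```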