Simplify ExpertBackend interface #483

Merged: 39 commits, Jun 15, 2022

Conversation

justheuristic (Member) commented Jun 14, 2022

The core idea is that we should not make hivemind internals conditional on a specific training technique, such as linear warmup or gradient clipping by norm. Instead, we let the user define their own scheduler and/or optimizer as necessary.

  • remove gradient clipping from ExpertBackend: this behavior can be achieved with a user-defined Optimizer (see the sketch after this list)
  • remove stats from ExpertBackend: this behavior can be achieved with a user-defined Scheduler
  • rename full_state -> state_dict, rationale: there is no "non-full" state in this context
  • rename ExpertBackend.expert -> ExpertBackend.module to avoid confusion
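
For example, gradient clipping can live entirely in user code. A minimal sketch, assuming plain PyTorch; the ClippingSGD name is hypothetical and not part of hivemind:

import torch

class ClippingSGD(torch.optim.SGD):
    """SGD that clips gradients by global norm before every step (illustrative only)."""

    def __init__(self, params, lr: float, max_grad_norm: float = 1.0, **kwargs):
        super().__init__(params, lr=lr, **kwargs)
        self.max_grad_norm = max_grad_norm

    @torch.no_grad()
    def step(self, closure=None):
        # clip all parameters across all param groups, then defer to plain SGD
        parameters = [p for group in self.param_groups for p in group["params"]]
        torch.nn.utils.clip_grad_norm_(parameters, self.max_grad_norm)
        return super().step(closure)

Passing such an optimizer to the server keeps clipping out of hivemind internals entirely.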

justheuristic and others added 14 commits June 12, 2022 23:05
@@ -54,7 +54,8 @@ def main():
     help='Server will report experts to DHT once in this many seconds')
 parser.add_argument('--expiration', type=float, required=False, default=None,
     help='DHT entries will expire after this many seconds')
-parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
+parser.add_argument('--num_training_steps', type=int, required=False, help='The total number of steps for LR schedule')
@@ -163,65 +141,47 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
 torch.autograd.backward(
     outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False
 )
-self.apply_gradients(batch_size)
+self.on_backward(batch_size)
justheuristic (Member Author):

rationale: this does not necessarily apply gradients, e.g.

  • virtual batching applies gradients once every k steps
  • pretrained models do not apply the returned gradients
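
A hedged sketch of the first point (the AccumulatingBackend name and its attributes are assumptions, not hivemind API): with virtual batching, on_backward only steps the optimizer once every k calls:

class AccumulatingBackend:
    """Illustrative only: applies accumulated gradients once every k backward calls."""

    def __init__(self, optimizer, accumulation_steps: int = 4):
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self._backward_calls = 0

    def on_backward(self, batch_size: int) -> None:
        # gradients were already accumulated by torch.autograd.backward above
        self._backward_calls += 1
        if self._backward_calls % self.accumulation_steps == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()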

codecov bot commented Jun 14, 2022

Codecov Report

Merging #483 (e05d3dc) into master (6c56a87) will decrease coverage by 0.19%.
The diff coverage is 75.45%.

@@            Coverage Diff             @@
##           master     #483      +/-   ##
==========================================
- Coverage   85.97%   85.78%   -0.20%     
==========================================
  Files          79       80       +1     
  Lines        7772     7808      +36     
==========================================
+ Hits         6682     6698      +16     
- Misses       1090     1110      +20     
Impacted Files Coverage Δ
hivemind/__init__.py 100.00% <ø> (ø)
hivemind/moe/__init__.py 100.00% <ø> (ø)
hivemind/moe/server/layers/dropout.py 96.87% <ø> (ø)
hivemind/moe/server/layers/optim.py 52.63% <52.63%> (ø)
hivemind/moe/server/checkpoints.py 82.45% <66.66%> (ø)
hivemind/moe/server/dht_handler.py 98.24% <75.00%> (ø)
hivemind/moe/server/server.py 79.67% <85.71%> (-0.44%) ⬇️
hivemind/moe/server/module_backend.py 94.66% <93.10%> (ø)
hivemind/hivemind_cli/run_server.py 80.32% <100.00%> (ø)
hivemind/moe/server/__init__.py 100.00% <100.00%> (ø)
... and 7 more

justheuristic marked this pull request as ready for review June 14, 2022 04:34
justheuristic requested a review from mryab June 14, 2022 04:36
optimizer = optim_cls(expert.parameters()) if optim_cls is not None else None
scheduler = scheduler_cls(optimizer) if scheduler_cls is not None else None
if clip_grad_norm is not None:
    scheduler = ClippingWrapper(scheduler, clip_grad_norm)
justheuristic (Member Author):

TODO FIX
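
A hedged guess at what this TODO refers to: since ClippingWrapper extends OptimizerWrapper (see below), it presumably should wrap the optimizer rather than the scheduler, e.g.:

optimizer = optim_cls(expert.parameters()) if optim_cls is not None else None
if clip_grad_norm is not None and optimizer is not None:
    optimizer = ClippingWrapper(optimizer, clip_grad_norm)
scheduler = scheduler_cls(optimizer) if scheduler_cls is not None else None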

hivemind/moe/server/layers/optim.py: resolved review thread (outdated)

import torch
from torch import nn

from hivemind.moe.server.task_pool import TaskPool
from hivemind.optim.state_averager import LRSchedulerBase
Reviewer (Member):

Since this import is not actually related to hivemind.optim, I’d suggest to simply inline the statement that is being imported.

justheuristic (Member Author):

fixed

It should return gradients w.r.t. inputs that follow ``nested_flatten(self.outputs_schema)``;

.. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
Reviewer (Member):

Is this not correct anymore? :)

justheuristic (Member Author):

This is arguably no different from gradient checkpointing, and it is unlikely that we can fix it here for all cases without the user defining custom layers. I can keep it if you insist [please reply here if so]. Alternatively, perhaps it would be best to change this todo into a warning/note. Your call?
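
For context, a small self-contained demo of the double-update effect under discussion (plain PyTorch, not hivemind code): re-running a module's forward pass, as a checkpoint-style backward does, updates BatchNorm running statistics a second time:

import torch
from torch import nn

bn = nn.BatchNorm1d(4)  # training mode by default
x = torch.randn(8, 4)

print(bn.num_batches_tracked)  # tensor(0)
bn(x)                          # first forward: running stats updated once
print(bn.num_batches_tracked)  # tensor(1)
bn(x)                          # recomputed forward: stats updated a second time
print(bn.num_batches_tracked)  # tensor(2)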

Reviewer (Member):

A warning would be sufficient, I suppose

justheuristic (Member Author):

restored the warning

hivemind/moe/server/expert_backend.py: two resolved review threads
"""A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer"""

def __init__(self, optim: torch.optim.Optimizer):
object.__init__(self)
Reviewer (Member):

object? In that case, it’s not a true Optimizer

justheuristic (Member Author):

On the contrary: it defines all optimizer fields and methods as properties.
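
A minimal sketch of that pattern, with names assumed to mirror the PR (not the exact implementation):

import torch

class OptimizerWrapper:
    """Exposes the torch.optim.Optimizer interface by forwarding to a wrapped optimizer."""

    def __init__(self, optim: torch.optim.Optimizer):
        object.__init__(self)  # plain object init, as in the PR
        self.optim = optim

    @property
    def param_groups(self):
        return self.optim.param_groups

    @property
    def state(self):
        return self.optim.state

    def step(self, closure=None):
        return self.optim.step(closure)

    def zero_grad(self, set_to_none: bool = True):
        return self.optim.zero_grad(set_to_none)

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict: dict):
        return self.optim.load_state_dict(state_dict)

Because the wrapper delegates rather than inherits, object.__init__ suffices; isinstance checks against torch.optim.Optimizer would still fail, which is the reviewer's point above.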

return self.optim.add_param_group(param_group)


class ClippingWrapper(OptimizerWrapper):
Reviewer (Member):

If we need OptimizerWrapper just for this one application, I’d suggest not overcomplicating the code and just writing one class for the specific use case.
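
The suggested simplification might look like this (hypothetical; collapses the wrapper and clipping into one class via attribute delegation):

import torch

class ClippedOptimizer:
    """Single-purpose alternative: clip gradients by norm, then delegate to the optimizer."""

    def __init__(self, optim: torch.optim.Optimizer, clip_grad_norm: float):
        self.optim, self.clip_grad_norm = optim, clip_grad_norm

    def step(self, closure=None):
        parameters = [p for group in self.optim.param_groups for p in group["params"]]
        torch.nn.utils.clip_grad_norm_(parameters, self.clip_grad_norm)
        return self.optim.step(closure)

    def __getattr__(self, name):
        # everything else (zero_grad, state_dict, ...) falls through to the wrapped optimizer
        return getattr(self.optim, name)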

hivemind/moe/server/server.py: two resolved review threads (one outdated)