Simplify ExpertBackend interface #483
First file in the diff:

@@ -1,9 +1,10 @@
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import nn
 
 from hivemind.moe.server.task_pool import TaskPool
+from hivemind.optim.state_averager import LRSchedulerBase
 from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor

Review comment (on the LRSchedulerBase import): Since this import is not actually related to hivemind.optim, I'd suggest simply inlining the statement that is being imported.
Reply: fixed

@@ -20,7 +21,7 @@ class ExpertBackend:
     - backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and **update expert**. Also batched.
     - get_info - return expert metadata. Not batched.
 
-    :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
+    :param module: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
 
     - Experts must always receive the same set of args and kwargs and produce output tensors of same type
     - All args, kwargs and outputs must be **tensors** where 0-th dimension represents to batch size

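For instance, a module satisfying these constraints might look as follows. This is an editorial illustration, not code from the PR; the class name and sizes are made up:

    import torch
    from torch import nn

    class FeedforwardExpert(nn.Module):
        """Toy expert: fixed argument signature, batch-first tensors in and out."""

        def __init__(self, hid_dim: int = 1024):
            super().__init__()
            self.ffn = nn.Sequential(
                nn.Linear(hid_dim, 4 * hid_dim), nn.GELU(), nn.Linear(4 * hid_dim, hid_dim)
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x has shape (batch_size, hid_dim); the output keeps the same batch-first convention
            return self.ffn(x)
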
@@ -34,49 +35,36 @@ class ExpertBackend:
     :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
     :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
     :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
-    :param num_warmup_steps: the number of warmup steps for LR schedule
-    :param num_total_steps: the total number of steps for LR schedule
-    :param clip_grad_norm: maximum gradient norm used for clipping
     :param kwargs: extra parameters to be forwarded into TaskPool.__init__
     """
 
     def __init__(
         self,
         name: str,
-        expert: nn.Module,
-        optimizer: torch.optim.Optimizer,
+        module: nn.Module,
         *,
-        scheduler: Callable = None,
+        optimizer: Optional[torch.optim.Optimizer] = None,
+        scheduler: Optional[LRSchedulerBase] = None,
         args_schema: Tuple[BatchTensorDescriptor, ...] = None,
         kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
         outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
-        num_warmup_steps: int = None,
-        num_total_steps: int = None,
-        clip_grad_norm: float = None,
         **kwargs,
     ):
         super().__init__()
-        self.expert, self.optimizer, self.name = expert, optimizer, name
-
-        if scheduler is None:
-            self.scheduler = None
-        else:
-            assert optimizer is not None and num_warmup_steps is not None and num_total_steps is not None
-            self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_total_steps)
-        self.clip_grad_norm = clip_grad_norm
+        self.name, self.module, self.optimizer, self.scheduler = name, module, optimizer, scheduler
 
         self.args_schema = args_schema = tuple(args_schema or ())
         self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
         assert args_schema or kwargs_schema, (
-            "expert must receive at least one positional or keyword input."
+            f"Module must take at least one positional or keyword input."
             " Did you forget to provide args_schema/kwargs_schema?"
         )
 
         if outputs_schema is None:
             # run expert once to get outputs schema
             dummy_args = tuple(sample.make_zeros(DUMMY_BATCH_SIZE) for sample in args_schema)
             dummy_kwargs = {key: sample.make_zeros(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
-            dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
+            dummy_outputs = self.module(*dummy_args, **dummy_kwargs)
             outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)
 
         self.forward_schema = (self.args_schema, self.kwargs_schema)  # inputs for forward

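In practice, the new keyword-only interface can be exercised roughly like this. A hedged sketch: the import path and the BatchTensorDescriptor argument follow hivemind's conventions but are assumptions, and FeedforwardExpert is the toy module from above. The key change is that optimizer and scheduler are built by the caller and passed in ready-made, instead of a scheduler factory plus num_warmup_steps/num_total_steps:

    import torch

    from hivemind.moe.server.expert_backend import ExpertBackend  # assumed import path
    from hivemind.utils.tensor_descr import BatchTensorDescriptor

    expert = FeedforwardExpert(hid_dim=1024)

    # Both optimizer and scheduler are now Optional, so an inference-only
    # backend may omit them entirely.
    optimizer = torch.optim.Adam(expert.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, step / 1000))

    backend = ExpertBackend(
        name="expert.0",
        module=expert,        # was: expert=...
        optimizer=optimizer,  # keyword-only and Optional now
        scheduler=scheduler,  # a ready LR scheduler, not a factory callable
        args_schema=(BatchTensorDescriptor(1024),),
    )
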
@@ -87,30 +75,22 @@ def __init__(
         self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
         self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
 
-        self.update_count = 0
-        self.examples_processed = 0
-
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
         To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``.
 
         Subclassing:
            This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``;
 
            It should return gradients w.r.t. inputs that follow ``nested_flatten(self.outputs_schema)``;
 
-           .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
-           .. For now, either register all buffers as outputs or avoid stateful experts
-
         """
         args, kwargs = nested_pack(inputs, structure=self.forward_schema)
 
         if args[0].shape[0] == 0:
             raise RuntimeError("Batch should contain more than 0 samples")
 
         with torch.no_grad():
-            outputs = self.expert(*args, **kwargs)
+            outputs = self.module(*args, **kwargs)
 
         # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
         return tuple(nested_flatten(outputs))

Review thread (on the removed "layer states" todo):
Comment: Is this not correct anymore? :)
Reply: This is arguably no different than in gradient checkpoints, and it is unlikely that we can fix it here for all cases without the user defining custom layers. I can keep it if you insist [please reply here if so]. Alternatively, perhaps it would be best to change this todo into a warning/note. Your call?
Reply: A warning would be sufficient, I suppose.
Reply: restored the warning

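The flat-tuple convention mentioned in the comment above can be seen in a toy round-trip. This is an editorial illustration assuming hivemind's nested utilities behave as their names suggest; the stand-in values are made up:

    from hivemind.utils.nested import nested_flatten, nested_pack

    structure = ((0, 0), {"mask": 0})                    # stand-in for (args_schema, kwargs_schema)
    flat = tuple(nested_flatten(((1, 2), {"mask": 3})))  # what TaskPool actually passes around
    assert flat == (1, 2, 3)
    args, kwargs = nested_pack(flat, structure=structure)
    assert args == (1, 2) and kwargs == {"mask": 3}      # restored on the other side
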
@@ -128,8 +108,6 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         Runtime doesn't guarantee that backward will be performed in the same order and for the same data
         as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward.
 
-        .. todo correct state handling (see forward)
-
         Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train
         """
         (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)

@@ -148,7 +126,7 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
 
         batch_size = args[0].size(0)
 
-        outputs = self.expert(*args, **kwargs)
+        outputs = self.module(*args, **kwargs)
         assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"
 
         outputs_flat = tuple(nested_flatten(outputs))

@@ -163,65 +141,47 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
             torch.autograd.backward(
                 outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False
             )
-            self.apply_gradients(batch_size)
+            self.on_backward(batch_size)
 
         return tuple(
             x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x) for x in nested_flatten((args, kwargs))
         )
 
-    def apply_gradients(self, batch_size) -> None:
+    def on_backward(self, batch_size: int) -> None:
         """
         Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
         """
-        if self.clip_grad_norm is not None:
-            torch.nn.utils.clip_grad_norm_(self.expert.parameters(), self.clip_grad_norm)
-
-        self.optimizer.step()
-        self.optimizer.zero_grad()
+        if self.optimizer is not None:
+            self.optimizer.step()
+            self.optimizer.zero_grad()
 
         if self.scheduler is not None:
             self.scheduler.step()
 
-        self.update_count += 1
-        self.examples_processed += batch_size
-
-    def get_stats(self) -> Dict:
-        """
-        Return current expert training statistics (number of updates, number of processed examples after
-        last optimizer step)
-        """
-        return {"updates": self.update_count, "examples_processed": self.examples_processed}
-
-    def get_full_state(self) -> Dict:
+    def state_dict(self) -> Dict:
         """
         Return the current state of the expert (including batch processing statistics)
         """
-        full_state = {
-            "stats": self.get_stats(),
-            "model": self.expert.state_dict(),
-            "optimizer": self.optimizer.state_dict(),
-            "scheduler": {} if self.scheduler is None else self.scheduler.state_dict(),
-        }
+        full_state = dict(module=self.module.state_dict())
+        if self.optimizer is not None:
+            full_state["optimizer"] = self.optimizer.state_dict()
+        if self.scheduler is not None:
+            full_state["scheduler"] = self.scheduler.state_dict()
         return full_state
 
-    def load_full_state(self, state_dict: Dict):
-        if "stats" in state_dict:
-            self.update_count = state_dict["stats"]["updates"]
-            self.examples_processed = state_dict["stats"]["examples_processed"]
-        else:
-            logger.warning(f"Batch processing stats missing for expert {self.name}")
-
-        self.expert.load_state_dict(state_dict["model"])
-
-        if "optimizer" in state_dict:
-            self.optimizer.load_state_dict(state_dict["optimizer"])
-        else:
-            logger.warning(f"Optimizer state missing for expert {self.name}")
-
-        if self.scheduler is not None and "scheduler" in state_dict:
-            self.scheduler.load_state_dict(state_dict["scheduler"])
-        else:
-            logger.warning(f"Learning rate scheduler state missing for expert {self.name}")
+    def load_state_dict(self, state_dict: Dict):
+        self.module.load_state_dict(state_dict["module"])
+        if self.optimizer is not None:
+            if "optimizer" in state_dict:
+                self.optimizer.load_state_dict(state_dict["optimizer"])
+            else:
+                logger.warning(f"Optimizer state missing for {self.name}")
+
+        if self.scheduler is not None:
+            if "scheduler" in state_dict:
+                self.scheduler.load_state_dict(state_dict["scheduler"])
+            else:
+                logger.warning(f"Learning rate scheduler state missing for {self.name}")
 
     def get_info(self) -> Dict[str, Any]:
         """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""

Review comment (on renaming apply_gradients to on_backward): rationale: this does not necessarily apply gradients, e.g. …

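Downstream, the renamed hooks line up with familiar torch conventions. The following is a hypothetical sketch, not code from the PR: LoggingExpertBackend is an invented subclass, and `backend` is the instance constructed in the earlier example. It assumes the post-PR names module, on_backward, state_dict, and load_state_dict:

    import torch

    class LoggingExpertBackend(ExpertBackend):
        """Hypothetical subclass: adds a log line around the (optional) optimizer/scheduler step."""

        def on_backward(self, batch_size: int) -> None:
            print(f"{self.name}: processed a batch of {batch_size} examples")
            super().on_backward(batch_size)  # steps optimizer/scheduler only if they were provided

    # Checkpointing mirrors torch's state_dict/load_state_dict pair; optimizer and
    # scheduler entries are simply absent when those components are None.
    checkpoint = backend.state_dict()  # {"module": ..., possibly "optimizer" and "scheduler"}
    torch.save(checkpoint, "expert.0.pt")
    backend.load_state_dict(torch.load("expert.0.pt"))
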
Second file in the diff (new file):

@@ -0,0 +1,63 @@
+import torch
+
+
+class OptimizerWrapper(torch.optim.Optimizer):
+    """A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer"""
+
+    def __init__(self, optim: torch.optim.Optimizer):
+        object.__init__(self)
+        self.optim = optim

Review thread (on ``object.__init__(self)``):
Comment: object? In that case, it's not a true Optimizer.
Reply: on the contrary, it defines all optimizer fields and methods as properties

+    @property
+    def defaults(self):
+        return self.optim.defaults
+
+    @property
+    def state(self):
+        return self.optim.state
+
+    def __getstate__(self):
+        return self.optim.__getstate__()
+
+    def __setstate__(self, state):
+        self.optim.__setstate__(state)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({repr(self.optim)})"
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        return self.optim.load_state_dict(state_dict)
+
+    def step(self, *args, **kwargs):
+        return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, *args, **kwargs):
+        return self.optim.zero_grad(*args, **kwargs)
+
+    @property
+    def param_groups(self):
+        return self.optim.param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        return self.optim.add_param_group(param_group)
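
To make the property-based delegation discussed in the review thread concrete, a small editorial check (uses only the names defined in this file plus torch's SGD):

    import torch
    from torch import nn

    param = nn.Parameter(torch.zeros(3))
    wrapped = OptimizerWrapper(torch.optim.SGD([param], lr=0.1))

    assert isinstance(wrapped, torch.optim.Optimizer)          # still passes isinstance checks
    assert wrapped.param_groups is wrapped.optim.param_groups  # state is read through, not copied
    wrapped.step()                                             # forwards to SGD.step()
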
+class ClippingWrapper(OptimizerWrapper):
+    """A wrapper to pytorch.optimizer that clips gradients by global norm before each step"""
+
+    def __init__(self, optim: torch.optim.Optimizer, clip_grad_norm: float):
+        super().__init__(optim)
+        self.clip_grad_norm = clip_grad_norm
+
+    def step(self, *args, **kwargs):
+        parameters = tuple(param for group in self.param_groups for param in group["params"])
+        torch.nn.utils.clip_grad_norm_(parameters, self.clip_grad_norm)
+        return super().step(*args, **kwargs)
+
+    @classmethod
+    def create(cls, optim_cls: type, *args, clip_grad_norm: float, **kwargs):
+        """Create a wrapped optimizer and wrap it with clipping"""
+        return cls(optim=optim_cls(*args, **kwargs), clip_grad_norm=clip_grad_norm)

Review comment (on ClippingWrapper): In case we need OptimizerWrapper just for this application, I'd suggest not to overcomplicate the code and just write one class for a specific use case.
Review comment: Name changed in the source: https://github.com/huggingface/transformers/blob/v4.19.4/src/transformers/optimization.py#L75