Simplify ExpertBackend interface #483

Merged: 39 commits, Jun 15, 2022

Changes from 16 commits

Commits (39)
e5a3e46
Increase default update_period to 30s, set default expiration to 2 * …
justheuristic Jun 12, 2022
637fb01
black-isort
justheuristic Jun 12, 2022
372d915
review
mryab Jun 12, 2022
98d4952
review
mryab Jun 12, 2022
b43c243
typo
justheuristic Jun 12, 2022
108a24b
add expiration param
justheuristic Jun 12, 2022
a5aa6f9
black-isort
justheuristic Jun 12, 2022
e940911
more requests
mryab Jun 12, 2022
2356bc1
review
mryab Jun 12, 2022
771ebc1
Update tests/test_dht_experts.py
mryab Jun 12, 2022
8fb0986
py39
justheuristic Jun 12, 2022
c210ecb
Merge remote-tracking branch 'origin/default_expiration' into default…
justheuristic Jun 12, 2022
ec22eda
- rename num_total_steps -> num_training_steps (to match the source)
justheuristic Jun 14, 2022
a0622bd
Merge branch 'master' into demo
justheuristic Jun 14, 2022
fa2da45
rename
justheuristic Jun 14, 2022
6c49fe9
black-isort
justheuristic Jun 14, 2022
2c77de0
rename
justheuristic Jun 14, 2022
b1873e1
un-hardcode experts from private interface on server side
justheuristic Jun 14, 2022
9f3187f
un-hardcode experts from private interface on server side
justheuristic Jun 14, 2022
d87a7b1
un-hardcode experts from private interface on server side
justheuristic Jun 14, 2022
65d622b
wrap optimizer, not scheduler
mryab Jun 14, 2022
a00fb9e
ModuleBackend
justheuristic Jun 14, 2022
5569c42
ModuleBackend
justheuristic Jun 14, 2022
add83b5
Update hivemind/moe/server/layers/optim.py
justheuristic Jun 14, 2022
9664d05
fix import
mryab Jun 14, 2022
c30bc6c
Merge remote-tracking branch 'origin/demo' into demo
justheuristic Jun 14, 2022
7aed0a8
review
mryab Jun 14, 2022
fa48f2c
review
justheuristic Jun 14, 2022
5601a95
review
mryab Jun 14, 2022
decf1b4
Update hivemind/moe/server/server.py
justheuristic Jun 14, 2022
409e035
review
justheuristic Jun 14, 2022
891a83b
Merge remote-tracking branch 'origin/demo' into demo
justheuristic Jun 14, 2022
dd6fc94
black-isort
justheuristic Jun 14, 2022
a9b7643
review
mryab Jun 15, 2022
8a2e1f2
Merge branch 'master' into demo
justheuristic Jun 15, 2022
efeb31b
review
mryab Jun 15, 2022
a63e8ef
Merge remote-tracking branch 'origin/demo' into demo
justheuristic Jun 15, 2022
04af589
review
justheuristic Jun 15, 2022
e05d3dc
review
justheuristic Jun 15, 2022
2 changes: 1 addition & 1 deletion benchmarks/benchmark_throughput.py
@@ -123,7 +123,7 @@ def benchmark_throughput(
expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
experts[f"expert.{i}"] = ExpertBackend(
name=f"expert.{i}",
expert=expert,
module=expert,
optimizer=torch.optim.Adam(expert.parameters()),
args_schema=(BatchTensorDescriptor(hid_dim),),
outputs_schema=BatchTensorDescriptor(hid_dim),
3 changes: 2 additions & 1 deletion hivemind/hivemind_cli/run_server.py
@@ -54,7 +54,8 @@ def main():
help='Server will report experts to DHT once in this many seconds')
parser.add_argument('--expiration', type=float, required=False, default=None,
help='DHT entries will expire after this many seconds')
parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
parser.add_argument('--num_training_steps', type=int, required=False, help='The total number of steps for LR schedule')
parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')

parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
4 changes: 2 additions & 2 deletions hivemind/moe/server/checkpoints.py
@@ -59,7 +59,7 @@ def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
expert_dir = Path(tmpdirname) / expert_name
expert_dir.mkdir()
checkpoint_name = expert_dir / f"checkpoint_{timestamp}.pt"
torch.save(expert_backend.get_full_state(), checkpoint_name)
torch.save(expert_backend.state_dict(), checkpoint_name)
os.symlink(checkpoint_name, expert_dir / "checkpoint_last.pt")
copy_tree(tmpdirname, str(checkpoint_dir))

@@ -70,6 +70,6 @@ def load_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
checkpoints_folder = checkpoint_dir / expert_name
latest_checkpoint = checkpoints_folder / "checkpoint_last.pt"
if latest_checkpoint.exists():
expert.load_full_state(torch.load(latest_checkpoint))
expert.load_state_dict(torch.load(latest_checkpoint))
else:
logger.warning(f"Failed to load checkpoint for expert {expert_name}")
108 changes: 34 additions & 74 deletions hivemind/moe/server/expert_backend.py
@@ -1,9 +1,10 @@
from typing import Any, Callable, Dict, Sequence, Tuple, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
from torch import nn

from hivemind.moe.server.task_pool import TaskPool
from hivemind.optim.state_averager import LRSchedulerBase
Review comment (Member):
Since this import is not actually related to hivemind.optim, I'd suggest to simply inline the statement that is being imported.

Reply (Member Author):
fixed

from hivemind.utils.logging import get_logger
from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
@@ -20,7 +21,7 @@ class ExpertBackend:
- backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and **update expert**. Also batched.
- get_info - return expert metadata. Not batched.

:param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
:param module: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:

- Experts must always receive the same set of args and kwargs and produce output tensors of same type
- All args, kwargs and outputs must be **tensors** where 0-th dimension represents to batch size
@@ -34,49 +35,36 @@ class ExpertBackend:
:param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
:param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
:param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
:param num_warmup_steps: the number of warmup steps for LR schedule
:param num_total_steps: the total number of steps for LR schedule
:param clip_grad_norm: maximum gradient norm used for clipping
:param kwargs: extra parameters to be forwarded into TaskPool.__init__
"""

def __init__(
self,
name: str,
expert: nn.Module,
optimizer: torch.optim.Optimizer,
module: nn.Module,
*,
scheduler: Callable = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerBase] = None,
args_schema: Tuple[BatchTensorDescriptor, ...] = None,
kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
num_warmup_steps: int = None,
num_total_steps: int = None,
clip_grad_norm: float = None,
**kwargs,
):
super().__init__()
self.expert, self.optimizer, self.name = expert, optimizer, name

if scheduler is None:
self.scheduler = None
else:
assert optimizer is not None and num_warmup_steps is not None and num_total_steps is not None
self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_total_steps)
self.clip_grad_norm = clip_grad_norm
self.name, self.module, self.optimizer, self.scheduler = name, module, optimizer, scheduler

self.args_schema = args_schema = tuple(args_schema or ())
self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
assert args_schema or kwargs_schema, (
"expert must receive at least one positional or keyword input."
f"Module must take at least one positional or keyword input."
" Did you forget to provide args_schema/kwargs_schema?"
)

if outputs_schema is None:
# run expert once to get outputs schema
dummy_args = tuple(sample.make_zeros(DUMMY_BATCH_SIZE) for sample in args_schema)
dummy_kwargs = {key: sample.make_zeros(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
dummy_outputs = self.module(*dummy_args, **dummy_kwargs)
outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)

self.forward_schema = (self.args_schema, self.kwargs_schema) # inputs for forward
@@ -87,30 +75,22 @@ def __init__(
self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs)

self.update_count = 0
self.examples_processed = 0

def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
"""
Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``.

Subclassing:
This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``;

It should return gradients w.r.t. inputs that follow ``nested_flatten(self.outputs_schema)``;

.. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
Review comment (Member):
Is this not correct anymore? :)

Reply (Member Author):
This is arguably no different than in gradient checkpointing, and it is unlikely that we can fix it here for all cases without the user defining custom layers. I can keep it if you insist [please reply here if so]. Alternatively, perhaps it would be best to change this todo into a warning/note. Your call?

Reply (Member):
A warning would be sufficient, I suppose.

Reply (Member Author):
restored the warning

.. For now, either register all buffers as outputs or avoid stateful experts

"""
args, kwargs = nested_pack(inputs, structure=self.forward_schema)

if args[0].shape[0] == 0:
raise RuntimeError("Batch should contain more than 0 samples")

with torch.no_grad():
outputs = self.expert(*args, **kwargs)
outputs = self.module(*args, **kwargs)

# Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
return tuple(nested_flatten(outputs))
@@ -128,8 +108,6 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
Runtime doesn't guarantee that backward will be performed in the same order and for the same data
as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward.

.. todo correct state handling (see forward)

Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train
"""
(args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
@@ -148,7 +126,7 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:

batch_size = args[0].size(0)

outputs = self.expert(*args, **kwargs)
outputs = self.module(*args, **kwargs)
assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"

outputs_flat = tuple(nested_flatten(outputs))
@@ -163,65 +141,47 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
torch.autograd.backward(
outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False
)
self.apply_gradients(batch_size)
self.on_backward(batch_size)
Review comment (Member Author):
Rationale: this does not necessarily apply gradients, e.g.
  • virtual batching applies gradients once every k steps
  • pretrained models do not apply the returned gradients

return tuple(
x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x) for x in nested_flatten((args, kwargs))
)

def apply_gradients(self, batch_size) -> None:
def on_backward(self, batch_size: int) -> None:
"""
Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
"""
if self.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(self.expert.parameters(), self.clip_grad_norm)

self.optimizer.step()
self.optimizer.zero_grad()
if self.optimizer is not None:
self.optimizer.step()
self.optimizer.zero_grad()

if self.scheduler is not None:
self.scheduler.step()

self.update_count += 1
self.examples_processed += batch_size

def get_stats(self) -> Dict:
"""
Return current expert training statistics (number of updates, number of processed examples after
last optimizer step)
"""
return {"updates": self.update_count, "examples_processed": self.examples_processed}

def get_full_state(self) -> Dict:
def state_dict(self) -> Dict:
"""
Return the current state of the expert (including batch processing statistics)
"""
full_state = {
"stats": self.get_stats(),
"model": self.expert.state_dict(),
"optimizer": self.optimizer.state_dict(),
"scheduler": {} if self.scheduler is None else self.scheduler.state_dict(),
}
full_state = dict(module=self.module.state_dict())
if self.optimizer is not None:
full_state["optimizer"] = self.optimizer.state_dict()
if self.scheduler is not None:
full_state["scheduler"] = self.scheduler.state_dict()
return full_state

def load_full_state(self, state_dict: Dict):
if "stats" in state_dict:
self.update_count = state_dict["stats"]["updates"]
self.examples_processed = state_dict["stats"]["examples_processed"]
else:
logger.warning(f"Batch processing stats missing for expert {self.name}")
def load_state_dict(self, state_dict: Dict):
self.module.load_state_dict(state_dict["module"])
if self.optimizer is not None:
if "optimizer" in state_dict:
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
logger.warning(f"Optimizer state missing for {self.name}")

self.expert.load_state_dict(state_dict["model"])

if "optimizer" in state_dict:
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
logger.warning(f"Optimizer state missing for expert {self.name}")

if self.scheduler is not None and "scheduler" in state_dict:
self.scheduler.load_state_dict(state_dict["scheduler"])
else:
logger.warning(f"Learning rate scheduler state missing for expert {self.name}")
if self.scheduler is not None:
if "scheduler" in state_dict:
self.scheduler.load_state_dict(state_dict["scheduler"])
else:
logger.warning(f"Learning rate scheduler state missing for {self.name}")

def get_info(self) -> Dict[str, Any]:
"""Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
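
Taken together, the diff above changes how a backend is constructed: the wrapped module is passed as `module=` (previously `expert=`), the optimizer becomes optional, the scheduler is passed as a ready-made scheduler object rather than a factory plus `num_warmup_steps`/`num_total_steps`, and gradient clipping moves out of the backend. A minimal construction sketch based on this diff; the toy module, optimizer, and scheduler below are placeholders, not code from the PR:

```python
import torch
from torch import nn
from hivemind.moe.server.expert_backend import ExpertBackend
from hivemind.utils.tensor_descr import BatchTensorDescriptor

hid_dim = 64  # placeholder hidden size
ffn = nn.Sequential(nn.Linear(hid_dim, 4 * hid_dim), nn.ReLU(), nn.Linear(4 * hid_dim, hid_dim))
opt = torch.optim.Adam(ffn.parameters())

backend = ExpertBackend(
    name="expert.0",
    module=ffn,  # keyword renamed from `expert=`
    optimizer=opt,  # optional: omit it for an inference-only module
    scheduler=torch.optim.lr_scheduler.LambdaLR(opt, lambda step: 1.0),  # an instance, not a factory
    args_schema=(BatchTensorDescriptor(hid_dim),),
    outputs_schema=BatchTensorDescriptor(hid_dim),
)

# on_backward(batch_size) replaces apply_gradients(batch_size): it steps the optimizer
# and scheduler when they are present; clipping is handled by the optimizer wrapper.
```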
63 changes: 63 additions & 0 deletions hivemind/moe/server/layers/optim.py
@@ -0,0 +1,63 @@
import torch


class OptimizerWrapper(torch.optim.Optimizer):
"""A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer"""

def __init__(self, optim: torch.optim.Optimizer):
object.__init__(self)
Review comment (Member):
object? In that case, it's not a true Optimizer

Reply (Member Author):
On the contrary, it defines all optimizer fields and methods as properties.

self.optim = optim

@property
def defaults(self):
return self.optim.defaults

@property
def state(self):
return self.optim.state

def __getstate__(self):
return self.optim.__getstate__()

def __setstate__(self, state):
self.optim.__setstate__(state)

def __repr__(self):
return f"{self.__class__.__name__}({repr(self.optim)})"

def state_dict(self):
return self.optim.state_dict()

def load_state_dict(self, state_dict: dict) -> None:
return self.optim.load_state_dict(state_dict)

def step(self, *args, **kwargs):
return self.optim.step(*args, **kwargs)

def zero_grad(self, *args, **kwargs):
return self.optim.zero_grad(*args, **kwargs)

@property
def param_groups(self):
return self.optim.param_groups

def add_param_group(self, param_group: dict) -> None:
return self.optim.add_param_group(param_group)


class ClippingWrapper(OptimizerWrapper):
Review comment (Member):
In case we need OptimizerWrapper just for this application, I'd suggest not to overcomplicate the code and just write one class for a specific use case.

"""A wrapper to pytorch.optimizer that clips gradients by global norm before each step"""

def __init__(self, optim: torch.optim.Optimizer, clip_grad_norm: float):
super().__init__(optim)
self.clip_grad_norm = clip_grad_norm

def step(self, *args, **kwargs):
parameters = tuple(param for group in self.param_groups for param in group["params"])
torch.nn.utils.clip_grad_norm_(parameters, self.clip_grad_norm)
return super().step(*args, **kwargs)

@classmethod
def create(cls, optim_cls: type, *args, clip_grad_norm: float, **kwargs):
"""Create a wrapped optimizer and wrap it with clipping"""
return cls(optim=optim_cls(*args, **kwargs), clip_grad_norm=clip_grad_norm)
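
For illustration, a usage sketch of the wrapper from this new file, which takes over the gradient clipping that the backend's `clip_grad_norm` argument used to do (the linear model and hyperparameters below are arbitrary placeholders):

```python
import torch
from torch import nn
from hivemind.moe.server.layers.optim import ClippingWrapper

model = nn.Linear(16, 16)  # placeholder module

# Wrap an existing optimizer ...
opt = ClippingWrapper(torch.optim.SGD(model.parameters(), lr=0.1), clip_grad_norm=1.0)

# ... or build and wrap it in one call with the classmethod defined above.
opt = ClippingWrapper.create(torch.optim.SGD, model.parameters(), lr=0.1, clip_grad_norm=1.0)

loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()
opt.step()       # clips the global grad norm to 1.0, then delegates to SGD.step()
opt.zero_grad()  # forwarded to the wrapped optimizer
```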
2 changes: 1 addition & 1 deletion hivemind/moe/server/runtime.py
@@ -70,7 +70,7 @@ def run(self):
pool.start()
if self.device is not None:
for expert_backend in self.expert_backends.values():
expert_backend.expert.to(self.device)
expert_backend.module.to(self.device)

with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
try: