Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prepare GradScaler for hivemind.Optimizer #413

Merged
merged 10 commits into from
Nov 18, 2021
4 changes: 2 additions & 2 deletions hivemind/optim/collaborative.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindG
self.averager.local_step = self.collaboration_state.optimizer_step
logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")

if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
if grad_scaler is not None and not grad_scaler.are_grads_finite(self.opt):
logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
self.local_samples_accumulated = self.local_steps_accumulated = 0
self.reset_accumulated_grads_()
Expand Down Expand Up @@ -310,7 +310,7 @@ def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindG

if grad_scaler is not None:
with grad_scaler.running_global_step():
assert grad_scaler.step(self)
assert grad_scaler.step(self.opt)
else:
self.opt.step()

Expand Down
62 changes: 38 additions & 24 deletions hivemind/optim/experimental/state_averager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import torch

import hivemind
from hivemind import nested_compare
from hivemind.averaging import DecentralizedAverager
from hivemind.compression import CompressionInfo, TensorRole
from hivemind.utils import get_logger, nested_flatten, nested_map, nested_pack
from hivemind.optim.grad_scaler import GradScaler
from hivemind.utils import get_logger, nested_flatten, nested_pack

logger = get_logger(__name__)

Expand Down Expand Up @@ -100,7 +100,7 @@ def __init__(
self.offload_optimizer = offload_optimizer
self.custom_gradients = custom_gradients

self._main_parameters, self._parameter_names = main_parameters, parameter_names
self.main_parameters, self.parameter_names = main_parameters, parameter_names
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

made them public because these params are needed in hivemind.Optimizer.step

... and they are no more private than, for instance, opt_keys_for_averaging

self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
self.optimizer, self.scheduler = self._init_components(
param_groups, optimizer, scheduler, initialize_optimizer
Expand Down Expand Up @@ -197,7 +197,7 @@ def _init_components(
initialize_optimizer = not any(isinstance(x, torch.Tensor) for x in nested_flatten(optimizer.state_dict()))
logger.log(
self.status_loglevel,
"Initializing optimizer manually since it has no tensors in state dict"
"Initializing optimizer manually since it has no tensors in state dict. "
"To override this, please provide initialize_optimizer=False",
)

Expand Down Expand Up @@ -257,12 +257,12 @@ def _init_averaged_tensors(self) -> Sequence[torch.Tensor]:
def _init_tensor_infos(self) -> Sequence[CompressionInfo]:
"""Get CompressionInfo for each state tensor, accounting for its role and specification"""
tensor_infos = []
for param, param_name in zip(self._main_parameters, self._parameter_names):
for param, param_name in zip(self.main_parameters, self.parameter_names):
tensor_infos.append(CompressionInfo.from_tensor(param, key=param_name, role=TensorRole.PARAMETER))
for stats_name in self.opt_keys_for_averaging:
opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
assert len(opt_parameters) == len(self._parameter_names)
for param, param_name in zip(opt_parameters, self._parameter_names):
assert len(opt_parameters) == len(self.parameter_names)
for param, param_name in zip(opt_parameters, self.parameter_names):
tensor_infos.append(
CompressionInfo.from_tensor(
self.optimizer.state[param][stats_name],
Expand All @@ -284,7 +284,8 @@ def step(
delay_optimizer_step: bool = False,
averaging_round: bool = False,
delay_averaging: Optional[bool] = None,
averaging_kwargs: Optional[Dict[str, Any]] = None,
grad_scaler: Optional[GradScaler] = None,
averaging_opts: Optional[Dict[str, Any]] = None,
):
"""
Perform one or several possible actions, depending on the specified keyword args.
Expand All @@ -298,9 +299,10 @@ def step(
:param zero_grad: if True, reset local gradients after performing optimizer step
:param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
:param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
:param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
:param delay_averaging: if True, perform averaging in background and apply results in a future step
by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
:param averaging_kwargs: a dict of keyword arguments forwarded into averaging round
:param averaging_opts: a dict of keyword arguments forwarded into averaging round
"""
if delay_averaging is None:
delay_averaging = delay_optimizer_step
Expand All @@ -312,8 +314,8 @@ def step(
if delay_optimizer_step:
assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
if averaging_kwargs and not averaging_round:
logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_kwargs}")
if averaging_opts and not averaging_round:
logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
output = None

if wait_for_delayed_update:
Expand All @@ -328,19 +330,17 @@ def step(
if self.finished_averaging_round.is_set():
if not self.reuse_tensors:
self._apply_averaging_results_()
logger.log(self.status_loglevel, "Received results from background averaging round")
logger.log(self.status_loglevel, "Received parameters from background averaging round")
self.finished_averaging_round.clear()

if self.finished_optimizer_step.is_set():
if self.offload_optimizer:
self._apply_optimizer_results_()
logger.log(self.status_loglevel, "Received results from background optimizer step")
logger.log(self.status_loglevel, "Received parameters from background optimizer step")
self.finished_optimizer_step.clear()

if increment_epoch:
self.local_epoch += 1
logger.log(self.status_loglevel, f"Switching to epoch {self.local_epoch}")
self._update_scheduler()

if optimizer_step or zero_grad or averaging_round:
assert self.pending_update.done(), "Tried to perform a new update but previous update is still running"
Expand All @@ -353,7 +353,8 @@ def step(
optimizer_step,
zero_grad,
averaging_round,
**averaging_kwargs or {},
grad_scaler,
**averaging_opts or {},
)

if (optimizer_step or zero_grad) and not delay_optimizer_step:
Expand All @@ -378,20 +379,33 @@ def step(
self.finished_optimizer_step.clear()
return output

def _do(self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, **kwargs):
def _do(
self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, grad_scaler: Optional[GradScaler], **kwargs
):
"""
Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
This method is meant to be called in the background executor.
"""
try:
if optimizer_step:
logger.log(self.status_loglevel, f"Running optimizer step")
self.optimizer.step()
if grad_scaler is None:
self.optimizer.step()
else:
with grad_scaler.running_global_step():
assert grad_scaler.step(self.optimizer)

if grad_scaler is not None:
with grad_scaler.running_global_step():
assert grad_scaler.update()

self._update_scheduler()

if zero_grad:
logger.log(self.status_loglevel, f"Running zero grad")
self.optimizer.zero_grad()
if self.offload_optimizer:
for parameter in self._main_parameters:
for parameter in self.main_parameters:
if parameter.grad is not None:
parameter.grad.zero_()

Expand Down Expand Up @@ -428,7 +442,7 @@ def _load_local_grads_into_optimizer_(self):
"""Copy local gradients into the gradient buffers of the offloaded optimizer"""
assert self.offload_optimizer, "Loading into offloaded optimizer requires using offloaded optimizer"
opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
for main_param, opt_param in zip(self._main_parameters, opt_parameters):
for main_param, opt_param in zip(self.main_parameters, opt_parameters):
if main_param.grad is not None:
opt_param.grad.copy_(main_param.grad, non_blocking=True)

Expand All @@ -438,8 +452,8 @@ def _apply_optimizer_results_(self):
assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
with self.lock_averaged_tensors:
offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
assert len(offloaded_parameters) == len(self._main_parameters), "opt parameters changed during training"
for main_param, offloaded_param in zip(self._main_parameters, offloaded_parameters):
assert len(offloaded_parameters) == len(self.main_parameters), "Optimizer parameters changed during training"
for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
main_param.copy_(offloaded_param, non_blocking=True)

@torch.no_grad()
Expand Down Expand Up @@ -471,7 +485,7 @@ def get_current_state(self):
)
parameter_infos = [
CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
for param, key in zip(optimized_parameters, self._parameter_names)
for param, key in zip(optimized_parameters, self.parameter_names)
]
extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
extra_infos = [
Expand All @@ -496,7 +510,7 @@ def load_state_from_peers(self, **kwargs):
Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
:returns: whether or the averager succeeded in loading parameters
"""
parameters_and_extras = tuple(chain(self._main_parameters, self.extra_tensors))
parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
num_parameters_and_extras = len(parameters_and_extras)

loaded_state = super().load_state_from_peers(**kwargs)
Expand Down
24 changes: 14 additions & 10 deletions hivemind/optim/grad_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
from torch.optim import Optimizer
from torch.optim import Optimizer as TorchOptimizer

from hivemind.optim.base import DecentralizedOptimizerBase
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


class HivemindGradScaler(TorchGradScaler):
class GradScaler(TorchGradScaler):
"""
A thin wrapper over pytorch GradScaler that supports hivemind-style training with CollaborativeOptimizer, namely:
- bypass .unscale_ and .update calls in order to accumulate gradients over several steps
Expand All @@ -33,7 +33,7 @@ def running_global_step(self):
finally:
self._is_running_global_step = was_running

def unscale_(self, optimizer: Optimizer) -> bool:
def unscale_(self, optimizer: TorchOptimizer) -> bool:
assert isinstance(optimizer, DecentralizedOptimizerBase)
if self._is_running_global_step:
super().unscale_(optimizer.opt)
Expand All @@ -43,11 +43,10 @@ def unscale_(self, optimizer: Optimizer) -> bool:
self._optimizer_states_to_reset.add(id(optimizer))
return False

def step(self, optimizer: Optimizer, *args, **kwargs) -> bool:
assert isinstance(optimizer, DecentralizedOptimizerBase)
def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
if self._is_running_global_step:
if self.are_grads_finite(optimizer):
super().step(optimizer.opt, *args, **kwargs)
super().step(optimizer, *args, **kwargs)
else:
logger.warning("Skipping global step due to gradient over/underflow")
return True
Expand All @@ -72,12 +71,17 @@ def update(self, new_scale: Optional[float] = None) -> bool:
return False

def _unscale_grads_(
self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
) -> Dict[torch.device, torch.Tensor]:
# note: the code below sets allow_fp16=True to allow training with master weights (partially) in fp16
# inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)

def are_grads_finite(self, optimizer: DecentralizedOptimizerBase) -> bool:
assert isinstance(optimizer, DecentralizedOptimizerBase)
return not sum(v.item() for v in self._check_inf_per_device(optimizer.opt).values())
def are_grads_finite(self, optimizer: TorchOptimizer) -> bool:
return not sum(v.item() for v in self._check_inf_per_device(optimizer).values())


class HivemindGradScaler(GradScaler):
def __init__(self, *args, **kwargs):
logger.warning("HivemindGradScaler was renamed to hivemind.GradScaler, this reference will be removed in v1.1")
super().__init__(*args, **kwargs)