From 9b7db56efad7edf761a412838405e2ae1b3c116e Mon Sep 17 00:00:00 2001 From: justheuristic Date: Thu, 18 Nov 2021 19:12:52 +0300 Subject: [PATCH 1/9] prepare GradScaler for hivemind.Optimizer --- hivemind/optim/collaborative.py | 4 ++-- hivemind/optim/grad_scaler.py | 27 +++++++++++++++------------ 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/hivemind/optim/collaborative.py b/hivemind/optim/collaborative.py index a816a8aab..967da01d0 100644 --- a/hivemind/optim/collaborative.py +++ b/hivemind/optim/collaborative.py @@ -245,7 +245,7 @@ def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindG self.averager.local_step = self.collaboration_state.optimizer_step logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.") - if grad_scaler is not None and not grad_scaler.are_grads_finite(self): + if grad_scaler is not None and not grad_scaler.are_grads_finite(self.opt): logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients") self.local_samples_accumulated = self.local_steps_accumulated = 0 self.reset_accumulated_grads_() @@ -310,7 +310,7 @@ def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindG if grad_scaler is not None: with grad_scaler.running_global_step(): - assert grad_scaler.step(self) + assert grad_scaler.step(self.opt) else: self.opt.step() diff --git a/hivemind/optim/grad_scaler.py b/hivemind/optim/grad_scaler.py index 1b7089c1a..b0b4f4c98 100644 --- a/hivemind/optim/grad_scaler.py +++ b/hivemind/optim/grad_scaler.py @@ -2,17 +2,16 @@ from typing import Dict, Optional import torch +from torch.optim import Optimizer as TorchOptimizer from torch.cuda.amp import GradScaler as TorchGradScaler from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state -from torch.optim import Optimizer -from hivemind.optim.base import DecentralizedOptimizerBase from hivemind.utils.logging import get_logger logger = get_logger(__name__) -class HivemindGradScaler(TorchGradScaler): +class GradScaler(TorchGradScaler): """ A thin wrapper over pytorch GradScaler that supports hivemind-style training with CollaborativeOptimizer, namely: - bypass .unscale_ and .update calls in order to accumulate gradients over several steps @@ -33,8 +32,8 @@ def running_global_step(self): finally: self._is_running_global_step = was_running - def unscale_(self, optimizer: Optimizer) -> bool: - assert isinstance(optimizer, DecentralizedOptimizerBase) + def unscale_(self, optimizer: TorchOptimizer) -> bool: + assert hasattr(optimizer, "opt"), "hivemind.GradScaler only supports hivemind optimizer wrappers" if self._is_running_global_step: super().unscale_(optimizer.opt) return True @@ -43,11 +42,10 @@ def unscale_(self, optimizer: Optimizer) -> bool: self._optimizer_states_to_reset.add(id(optimizer)) return False - def step(self, optimizer: Optimizer, *args, **kwargs) -> bool: - assert isinstance(optimizer, DecentralizedOptimizerBase) + def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool: if self._is_running_global_step: if self.are_grads_finite(optimizer): - super().step(optimizer.opt, *args, **kwargs) + super().step(optimizer, *args, **kwargs) else: logger.warning("Skipping global step due to gradient over/underflow") return True @@ -72,12 +70,17 @@ def update(self, new_scale: Optional[float] = None) -> bool: return False def _unscale_grads_( - self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool + self, optimizer: 
TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool ) -> Dict[torch.device, torch.Tensor]: # note: the code below sets allow_fp16=True to allow training with master weights (partially) in fp16 # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True) - def are_grads_finite(self, optimizer: DecentralizedOptimizerBase) -> bool: - assert isinstance(optimizer, DecentralizedOptimizerBase) - return not sum(v.item() for v in self._check_inf_per_device(optimizer.opt).values()) + def are_grads_finite(self, optimizer: TorchOptimizer) -> bool: + return not sum(v.item() for v in self._check_inf_per_device(optimizer).values()) + + +class HivemindGradScaler(GradScaler): + def __init__(self, *args, **kwargs): + logger.warning("HivemindGradScaler was renamed to hivemind.GradScaler, this reference will be removed in v1.1") + super().__init__(*args, **kwargs) From db53507dda1b67b1b31c0fa12efd8a2e07716380 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Thu, 18 Nov 2021 19:15:13 +0300 Subject: [PATCH 2/9] support GradScaler in state_averager.py --- hivemind/optim/experimental/state_averager.py | 59 +++++++++++-------- 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/hivemind/optim/experimental/state_averager.py b/hivemind/optim/experimental/state_averager.py index c81317309..038b8bd1e 100644 --- a/hivemind/optim/experimental/state_averager.py +++ b/hivemind/optim/experimental/state_averager.py @@ -9,10 +9,10 @@ import torch import hivemind -from hivemind import nested_compare from hivemind.averaging import DecentralizedAverager from hivemind.compression import CompressionInfo, TensorRole -from hivemind.utils import get_logger, nested_flatten, nested_map, nested_pack +from hivemind.optim.grad_scaler import GradScaler +from hivemind.utils import get_logger, nested_flatten, nested_pack logger = get_logger(__name__) @@ -100,7 +100,7 @@ def __init__( self.offload_optimizer = offload_optimizer self.custom_gradients = custom_gradients - self._main_parameters, self._parameter_names = main_parameters, parameter_names + self.main_parameters, self.parameter_names = main_parameters, parameter_names self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters)) self.optimizer, self.scheduler = self._init_components( param_groups, optimizer, scheduler, initialize_optimizer @@ -197,7 +197,7 @@ def _init_components( initialize_optimizer = not any(isinstance(x, torch.Tensor) for x in nested_flatten(optimizer.state_dict())) logger.log( self.status_loglevel, - "Initializing optimizer manually since it has no tensors in state dict" + "Initializing optimizer manually since it has no tensors in state dict. 
" "To override this, please provide initialize_optimizer=False", ) @@ -257,12 +257,12 @@ def _init_averaged_tensors(self) -> Sequence[torch.Tensor]: def _init_tensor_infos(self) -> Sequence[CompressionInfo]: """Get CompressionInfo for each state tensor, accounting for its role and specification""" tensor_infos = [] - for param, param_name in zip(self._main_parameters, self._parameter_names): + for param, param_name in zip(self.main_parameters, self.parameter_names): tensor_infos.append(CompressionInfo.from_tensor(param, key=param_name, role=TensorRole.PARAMETER)) for stats_name in self.opt_keys_for_averaging: opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]] - assert len(opt_parameters) == len(self._parameter_names) - for param, param_name in zip(opt_parameters, self._parameter_names): + assert len(opt_parameters) == len(self.parameter_names) + for param, param_name in zip(opt_parameters, self.parameter_names): tensor_infos.append( CompressionInfo.from_tensor( self.optimizer.state[param][stats_name], @@ -284,7 +284,8 @@ def step( delay_optimizer_step: bool = False, averaging_round: bool = False, delay_averaging: Optional[bool] = None, - averaging_kwargs: Optional[Dict[str, Any]] = None, + grad_scaler: Optional[GradScaler] = None, + averaging_opts: Optional[Dict[str, Any]] = None, ): """ Perform one or several possible actions, depending on the specified keyword args. @@ -298,9 +299,10 @@ def step( :param zero_grad: if True, reset local gradients after performing optimizer step :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers + :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_ :param delay_averaging: if True, perform averaging in background and apply results in a future step by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase. 
- :param averaging_kwargs: a dict of keyword arguments forwarded into averaging round + :param averaging_opts: a dict of keyword arguments forwarded into averaging round """ if delay_averaging is None: delay_averaging = delay_optimizer_step @@ -312,8 +314,8 @@ def step( if delay_optimizer_step: assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer" assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed" - if averaging_kwargs and not averaging_round: - logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_kwargs}") + if averaging_opts and not averaging_round: + logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}") output = None if wait_for_delayed_update: @@ -328,19 +330,17 @@ def step( if self.finished_averaging_round.is_set(): if not self.reuse_tensors: self._apply_averaging_results_() - logger.log(self.status_loglevel, "Received results from background averaging round") + logger.log(self.status_loglevel, "Received parameters from background averaging round") self.finished_averaging_round.clear() if self.finished_optimizer_step.is_set(): if self.offload_optimizer: self._apply_optimizer_results_() - logger.log(self.status_loglevel, "Received results from background optimizer step") + logger.log(self.status_loglevel, "Received parameters from background optimizer step") self.finished_optimizer_step.clear() if increment_epoch: self.local_epoch += 1 - logger.log(self.status_loglevel, f"Switching to epoch {self.local_epoch}") - self._update_scheduler() if optimizer_step or zero_grad or averaging_round: assert self.pending_update.done(), "Tried to perform a new update but previous update is still running" @@ -353,7 +353,7 @@ def step( optimizer_step, zero_grad, averaging_round, - **averaging_kwargs or {}, + **averaging_opts or {}, ) if (optimizer_step or zero_grad) and not delay_optimizer_step: @@ -378,7 +378,7 @@ def step( self.finished_optimizer_step.clear() return output - def _do(self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, **kwargs): + def _do(self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, grad_scaler: Optional[GradScaler], **kwargs): """ Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional. This method is meant to be called in the background executor. 
@@ -386,12 +386,23 @@ def _do(self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, **kw try: if optimizer_step: logger.log(self.status_loglevel, f"Running optimizer step") - self.optimizer.step() + if grad_scaler is None: + self.optimizer.step() + else: + with grad_scaler.running_global_step(): + assert grad_scaler.step(self.optimizer) + + if grad_scaler is not None: + with grad_scaler.running_global_step(): + assert grad_scaler.update() + + self._update_scheduler() + if zero_grad: logger.log(self.status_loglevel, f"Running zero grad") self.optimizer.zero_grad() if self.offload_optimizer: - for parameter in self._main_parameters: + for parameter in self.main_parameters: if parameter.grad is not None: parameter.grad.zero_() @@ -428,7 +439,7 @@ def _load_local_grads_into_optimizer_(self): """Copy local gradients into the gradient buffers of the offloaded optimizer""" assert self.offload_optimizer, "Loading into offloaded optimizer requires using offloaded optimizer" opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]] - for main_param, opt_param in zip(self._main_parameters, opt_parameters): + for main_param, opt_param in zip(self.main_parameters, opt_parameters): if main_param.grad is not None: opt_param.grad.copy_(main_param.grad, non_blocking=True) @@ -438,8 +449,8 @@ def _apply_optimizer_results_(self): assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer" with self.lock_averaged_tensors: offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]] - assert len(offloaded_parameters) == len(self._main_parameters), "opt parameters changed during training" - for main_param, offloaded_param in zip(self._main_parameters, offloaded_parameters): + assert len(offloaded_parameters) == len(self.main_parameters), "opt parameters changed during training" + for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters): main_param.copy_(offloaded_param, non_blocking=True) @torch.no_grad() @@ -471,7 +482,7 @@ def get_current_state(self): ) parameter_infos = [ CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER) - for param, key in zip(optimized_parameters, self._parameter_names) + for param, key in zip(optimized_parameters, self.parameter_names) ] extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors) extra_infos = [ @@ -496,7 +507,7 @@ def load_state_from_peers(self, **kwargs): Attempt to download the latest optimizer state from peers and update trainer parameters/statistics. 
:returns: whether or the averager succeeded in loading parameters """ - parameters_and_extras = tuple(chain(self._main_parameters, self.extra_tensors)) + parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors)) num_parameters_and_extras = len(parameters_and_extras) loaded_state = super().load_state_from_peers(**kwargs) From d667541fff9053659b4b767987589bc34ccab021 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Thu, 18 Nov 2021 19:16:41 +0300 Subject: [PATCH 3/9] black-isort --- hivemind/optim/experimental/state_averager.py | 4 +++- hivemind/optim/grad_scaler.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/hivemind/optim/experimental/state_averager.py b/hivemind/optim/experimental/state_averager.py index 038b8bd1e..eddc15291 100644 --- a/hivemind/optim/experimental/state_averager.py +++ b/hivemind/optim/experimental/state_averager.py @@ -378,7 +378,9 @@ def step( self.finished_optimizer_step.clear() return output - def _do(self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, grad_scaler: Optional[GradScaler], **kwargs): + def _do( + self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, grad_scaler: Optional[GradScaler], **kwargs + ): """ Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional. This method is meant to be called in the background executor. diff --git a/hivemind/optim/grad_scaler.py b/hivemind/optim/grad_scaler.py index b0b4f4c98..d1eab18eb 100644 --- a/hivemind/optim/grad_scaler.py +++ b/hivemind/optim/grad_scaler.py @@ -2,9 +2,9 @@ from typing import Dict, Optional import torch -from torch.optim import Optimizer as TorchOptimizer from torch.cuda.amp import GradScaler as TorchGradScaler from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state +from torch.optim import Optimizer as TorchOptimizer from hivemind.utils.logging import get_logger From d806748adb6e284d6575b3b560d03c48e5d9e3e9 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Thu, 18 Nov 2021 19:22:24 +0300 Subject: [PATCH 4/9] black --- hivemind/optim/experimental/state_averager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/hivemind/optim/experimental/state_averager.py b/hivemind/optim/experimental/state_averager.py index eddc15291..75a8ced8e 100644 --- a/hivemind/optim/experimental/state_averager.py +++ b/hivemind/optim/experimental/state_averager.py @@ -353,6 +353,7 @@ def step( optimizer_step, zero_grad, averaging_round, + grad_scaler, **averaging_opts or {}, ) From 50cc792c7855afc74b44df3450bb693cfafd7e7c Mon Sep 17 00:00:00 2001 From: justheuristic Date: Thu, 18 Nov 2021 19:22:44 +0300 Subject: [PATCH 5/9] Update hivemind/optim/experimental/state_averager.py Co-authored-by: Max Ryabinin --- hivemind/optim/experimental/state_averager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hivemind/optim/experimental/state_averager.py b/hivemind/optim/experimental/state_averager.py index 75a8ced8e..f9a42fa5c 100644 --- a/hivemind/optim/experimental/state_averager.py +++ b/hivemind/optim/experimental/state_averager.py @@ -452,7 +452,7 @@ def _apply_optimizer_results_(self): assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer" with self.lock_averaged_tensors: offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]] - assert len(offloaded_parameters) == len(self.main_parameters), "opt parameters changed during training" + assert len(offloaded_parameters) == 
len(self.main_parameters), "Optimizer parameters changed during training" for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters): main_param.copy_(offloaded_param, non_blocking=True) From c62370b551c97947ea76bc41d879721e92ebe414 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Thu, 18 Nov 2021 19:25:32 +0300 Subject: [PATCH 6/9] review --- hivemind/optim/grad_scaler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hivemind/optim/grad_scaler.py b/hivemind/optim/grad_scaler.py index d1eab18eb..a20807ff3 100644 --- a/hivemind/optim/grad_scaler.py +++ b/hivemind/optim/grad_scaler.py @@ -33,7 +33,7 @@ def running_global_step(self): self._is_running_global_step = was_running def unscale_(self, optimizer: TorchOptimizer) -> bool: - assert hasattr(optimizer, "opt"), "hivemind.GradScaler only supports hivemind optimizer wrappers" + assert isinstance(optimizer, DecentralizedOptimizerBase) if self._is_running_global_step: super().unscale_(optimizer.opt) return True From 24fabe109a5046bd2d67a1895218665e91577306 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Thu, 18 Nov 2021 19:26:02 +0300 Subject: [PATCH 7/9] review --- hivemind/optim/grad_scaler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/hivemind/optim/grad_scaler.py b/hivemind/optim/grad_scaler.py index a20807ff3..11863f8a3 100644 --- a/hivemind/optim/grad_scaler.py +++ b/hivemind/optim/grad_scaler.py @@ -6,6 +6,7 @@ from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state from torch.optim import Optimizer as TorchOptimizer +from hivemind import DecentralizedOptimizerBase from hivemind.utils.logging import get_logger logger = get_logger(__name__) From f1da4ef50df543d74ea30d6488cd55b603e3bb46 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Thu, 18 Nov 2021 19:26:53 +0300 Subject: [PATCH 8/9] black-isort --- hivemind/optim/grad_scaler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hivemind/optim/grad_scaler.py b/hivemind/optim/grad_scaler.py index 11863f8a3..c575a5460 100644 --- a/hivemind/optim/grad_scaler.py +++ b/hivemind/optim/grad_scaler.py @@ -6,7 +6,7 @@ from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state from torch.optim import Optimizer as TorchOptimizer -from hivemind import DecentralizedOptimizerBase +from hivemind.optim.base import DecentralizedOptimizerBase from hivemind.utils.logging import get_logger logger = get_logger(__name__) From b76b710150eb02d7896a89fbba1016df933da1cd Mon Sep 17 00:00:00 2001 From: justheuristic Date: Thu, 18 Nov 2021 19:29:59 +0300 Subject: [PATCH 9/9] black --- hivemind/optim/experimental/state_averager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/hivemind/optim/experimental/state_averager.py b/hivemind/optim/experimental/state_averager.py index f9a42fa5c..8e979ee57 100644 --- a/hivemind/optim/experimental/state_averager.py +++ b/hivemind/optim/experimental/state_averager.py @@ -452,7 +452,9 @@ def _apply_optimizer_results_(self): assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer" with self.lock_averaged_tensors: offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]] - assert len(offloaded_parameters) == len(self.main_parameters), "Optimizer parameters changed during training" + assert len(offloaded_parameters) == len( + self.main_parameters + ), "Optimizer parameters changed during training" for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters): 
main_param.copy_(offloaded_param, non_blocking=True)
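
Taken together, these patches route mixed-precision training through the renamed hivemind.GradScaler: between collaboration-wide steps the scaler defers unscale_ and update, and it only touches the wrapped inner optimizer while running_global_step() is active. The sketch below shows the intended call order from the user side. It is illustrative only, not part of this patch set: model, loss_fn and make_batches are placeholders, and the CollaborativeOptimizer constructor arguments are abbreviated assumptions.

    import torch
    from torch.cuda.amp import autocast

    import hivemind
    from hivemind.optim.grad_scaler import GradScaler  # old name HivemindGradScaler still works but warns

    # placeholders: a real script defines its own model, loss and data pipeline
    dht = hivemind.DHT(start=True)
    opt = hivemind.CollaborativeOptimizer(
        opt=torch.optim.Adam(model.parameters()),
        dht=dht,
        prefix="my_experiment",
        target_batch_size=4096,
        batch_size_per_step=32,
    )
    scaler = GradScaler()

    for batch in make_batches():
        with autocast():
            loss = loss_fn(model(batch))

        scaler.scale(loss).backward()

        # between global steps this only records the optimizer for a deferred unscale
        scaler.unscale_(opt)

        # forward the scaler so the optimizer can run scaler.step() / scaler.update()
        # under running_global_step() once the collaboration reaches target_batch_size
        opt.step(batch_size=32, grad_scaler=scaler)

        # a no-op (returns False) between global steps; the real update runs inside the wrapper
        scaler.update()
        opt.zero_grad()

Deferring unscale_ and update this way keeps a single loss scale consistent across the whole gradient accumulation window, instead of letting every local step shrink or grow it independently.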