diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 96210d94d14a..4ba3a2be7627 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -21,7 +21,7 @@
 from typing import Callable, Dict, Optional, Union, Iterable
 
-from deepspeed.runtime.utils import see_memory_usage, get_ma_status
+from deepspeed.runtime.utils import see_memory_usage, get_ma_status, DummyOptim
 from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
 from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
@@ -293,6 +293,9 @@ def __init__(
             self._configure_optimizer(optimizer, model_parameters)
             self._configure_lr_scheduler(lr_scheduler)
             self._report_progress(0)
+        elif self.zero_optimization():
+            # no optim selected but zero is enabled
+            self.optimizer = self._configure_zero_optimizer(optimizer=None)
 
         self._get_model_parameters()
 
@@ -1312,9 +1315,13 @@ def _configure_zero_optimizer(self, optimizer):
         ), "ZeRO does not support 'fp32_allreduce': true"
         timers = self.timers if self.wall_clock_breakdown() else None
 
+        if optimizer is None:
+            optimizer = DummyOptim(list(self.module.parameters()))
+
         if self.zero_legacy_stage1(
         ) and zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
             assert not self.has_moe_layers, "MoE not supported with Stage 1"
+            assert not isinstance(optimizer, DummyOptim), "zero stage 1 requires an optimizer"
 
             optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
                 optimizer,
@@ -1336,6 +1343,7 @@ def _configure_zero_optimizer(self, optimizer):
             overlap_comm = self.zero_overlap_comm()
             contiguous_gradients = self.zero_contiguous_gradients()
             round_robin_gradients = self.zero_round_robin_gradients()
+            assert not isinstance(optimizer, DummyOptim), "zero stage 2 requires an optimizer"
 
             # Overlap and contiguous grads are meaningless in stage 1 and are ignored
             if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
@@ -1687,9 +1695,8 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False):
 
         self._start_timers(self.engine_timers.backward_timers)
 
-        assert self.optimizer is not None, (
-            "must provide optimizer during " "init in order to use backward"
-        )
+        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
+            "must provide optimizer during init in order to use backward"
 
         self._start_timers(self.engine_timers.backward_inner_timers)
@@ -1835,9 +1842,9 @@ def step(self, lr_kwargs=None):
 
         self._start_timers(self.engine_timers.step_timers)
 
-        assert self.optimizer is not None, (
-            "must provide optimizer during " "init in order to use step"
-        )
+        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
+            "must provide optimizer during init in order to use step"
+
         report_progress = self.global_rank == 0 if self.global_rank else True
 
         self._step_applied = False  # assume False, will flip to True
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index a73cbf427a57..a45d0c256a76 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -32,6 +32,16 @@
     torch_max_memory_reserved = torch.cuda.memory_cached
 
 
+class DummyOptim():
+    """
+    Dummy optimizer presents model parameters as a param group, this is
+    primarily used to allow ZeRO-3 without an optimizer
+    """
+    def __init__(self, params):
+        self.param_groups = []
+        self.param_groups.append({'params': params})
+
+
 def noop_decorator(func):
     return func
diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 4cf5a8f50f7b..665ff0704696 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -369,10 +369,11 @@ def __init__(self,
             remote_device (string, optional): The initial device to store model
                 weights e.g., ``cpu``, ``nvme``. Passing ``"cpu"`` will create the model in CPU
                 memory. The model may still be moved to GPU based on the
-                offload settings for training. Defaults to the local GPU.
+                offload settings for training. Defaults to param offload device if a config is
+                defined, otherwise GPU.
             pin_memory (bool, optional): Potentially increase performance by
                 using pinned memory for model weights. ``remote_device`` must be
-                ``"cpu"``. Defaults to ``False``.
+                ``"cpu"``. Defaults to pin_memory value in config, otherwise ``False``.
             config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration
                 for swapping fp16 params to NVMe.
             config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
@@ -478,6 +479,10 @@ def get_model():
         # It is the device where parameters are fully instantiated using allgather
         self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
 
+        if _ds_config is not None and _ds_config.zero_config.offload_param is not None:
+            remote_device = _ds_config.zero_config.offload_param[OFFLOAD_PARAM_DEVICE]
+            pin_memory = _ds_config.zero_config.offload_param[OFFLOAD_PARAM_PIN_MEMORY]
+
         self._validate_remote_device(remote_device, _ds_config)
 
         # Remote device is the device where parameter partiitons are stored
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index c306114d927b..1387ced6fc8d 100755
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -16,7 +16,7 @@
 from deepspeed.utils.logging import logger
 from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
-from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter
+from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.partition_parameters import _init_external_params
 from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS
@@ -792,15 +792,18 @@ def __init__(self,
             self._configure_tensor_swapping(offload_optimizer_config, aio_config)
 
         see_memory_usage("Before creating fp32 partitions", force=False)
-        self._create_fp32_partitions()
+        if not isinstance(self.optimizer, DummyOptim):
+            self._create_fp32_partitions()
         see_memory_usage("After creating fp32 partitions", force=False)
         dist.barrier()
 
         # To support pipelined optimizer swapping
-        self._create_next_swappable_fp32_groups()
+        if not isinstance(init_optimizer, DummyOptim):
+            self._create_next_swappable_fp32_groups()
 
         see_memory_usage("Before initializing optimizer states", force=False)
-        self.initialize_optimizer_states()
+        if not isinstance(init_optimizer, DummyOptim):
+            self.initialize_optimizer_states()
         see_memory_usage("After initializing optimizer states", force=False)
         dist.barrier()
@@ -920,7 +923,8 @@ def _configure_offloading(self, offload_optimizer_config, offload_param_config):
 
         ###################### offload param setup ##################################
         if offload_param_config is not None:
-            assert self.offload_optimizer, "parameter offload is only available with optimizer state offload"
+            if not isinstance(self.optimizer, DummyOptim):
+                assert self.offload_optimizer, "parameter offload is only available with optimizer state offload"
             self.offload_param = True
             self.offload_param_pin_memory = offload_param_config[
                 OFFLOAD_PARAM_PIN_MEMORY]
diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py
index 8840e4de0abb..78901eb3a40d 100755
--- a/tests/unit/simple_model.py
+++ b/tests/unit/simple_model.py
@@ -8,18 +8,23 @@
 
 
 class SimpleModel(torch.nn.Module):
-    def __init__(self, hidden_dim, empty_grad=False):
+    def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
         super(SimpleModel, self).__init__()
-        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
+        self.linears = torch.nn.ModuleList(
+            [torch.nn.Linear(hidden_dim,
+                             hidden_dim) for i in range(nlayers)])
         if empty_grad:
             self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
         self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
         self.empty_grad = empty_grad
 
     def forward(self, x, y):
-        hidden_dim = x
-        hidden_dim = self.linear(hidden_dim)
-        return self.cross_entropy_loss(hidden_dim, y)
+        if len(self.linears) == 1:
+            x = self.linears[0](x)
+        else:
+            for i, l in enumerate(self.linears):
+                x = self.linears[i // 2](x)
+                l(x)
+        return self.cross_entropy_loss(x, y)
 
 
 class Curriculum_SimpleModel(SimpleModel):
diff --git a/tests/unit/test_ds_initialize.py b/tests/unit/test_ds_initialize.py
index 2dfa3b481eec..04e6545b887e 100644
--- a/tests/unit/test_ds_initialize.py
+++ b/tests/unit/test_ds_initialize.py
@@ -4,13 +4,58 @@
 from torch.optim import Optimizer, Adam, AdamW
 from torch.optim.lr_scheduler import _LRScheduler, LambdaLR
 
-from simple_model import args_from_dict, SimpleModel
+from simple_model import args_from_dict, SimpleModel, random_dataloader
 from common import distributed_test
+from util import required_torch_version
 
 import deepspeed
 from deepspeed.ops.adam import FusedAdam
 from deepspeed.runtime.lr_schedules import WARMUP_LR, WarmupLR
 from deepspeed.runtime.config import ADAM_OPTIMIZER
+from deepspeed.runtime.utils import see_memory_usage
+
+
+@pytest.mark.parametrize('zero_stage,world_size', [(0, 1), (3, 1)])
+def test_no_optim(zero_stage, world_size):
+    if zero_stage == 3 and not required_torch_version():
+        pytest.skip("zero-3 param offload requires at least torch 1.8")
+
+    ds_config = {
+        'train_batch_size': world_size,
+        'fp16': {
+            'enabled': True
+        },
+        'zero_optimization': {
+            "stage": zero_stage,
+            "offload_param": {
+                "device": "cpu"
+            }
+        }
+    }
+    # 20B test
+    #hidden_dim = 16 * 1024
+    hidden_dim = 4
+
+    @distributed_test(world_size=[world_size])
+    def _go(hidden_dim):
+        with deepspeed.zero.Init(enabled=zero_stage == 3, config_dict_or_path=ds_config):
+            model = SimpleModel(hidden_dim, nlayers=78)
+        print('total number of parameters:',
+              sum([p.numel() for p in model.parameters()]))
+        see_memory_usage('pre-init', force=True)
+        model, _, _, _ = deepspeed.initialize(model=model, config=ds_config)
+        see_memory_usage('post-init', force=True)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device,
+                                        dtype=torch.half)
+        print(f"optimizer={model.optimizer}")
+        for batch in data_loader:
+            model(batch[0], batch[1])
+        see_memory_usage('post-fwds', force=True)
+
+    _go(hidden_dim)
 
 
 @pytest.mark.parametrize('optimizer_type', [None, Optimizer, Callable])
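
Usage note (not part of the patch): the sketch below shows the forward-only workflow this change enables, closely following the new test_no_optim test above. The ToyModel class, tensor shapes, and the single-GPU launch assumption (e.g. started with the deepspeed launcher so distributed state is initialized) are illustrative; the config keys and the zero.Init / deepspeed.initialize calls mirror the test. With no optimizer passed, the engine falls back to DummyOptim, so forward passes work while engine.backward() and engine.step() assert.

# Minimal sketch, assuming a single GPU and a deepspeed-style launch.
# ToyModel and the batch shape are illustrative; config keys mirror test_no_optim.
import torch
import deepspeed

ds_config = {
    "train_batch_size": 1,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        # also picked up by zero.Init as the new remote_device/pin_memory defaults
        "offload_param": {"device": "cpu"},
    },
}


class ToyModel(torch.nn.Module):
    def __init__(self, hidden_dim=4, nlayers=8):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


# Partition (and CPU-offload) parameters at construction time.
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = ToyModel()

# No optimizer argument: the engine falls back to DummyOptim internally.
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

x = torch.randn(1, 4, dtype=torch.half, device=engine.device)
out = engine(x)
# engine.backward(out.sum()) or engine.step() would assert here,
# since DummyOptim is not a real optimizer.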