diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 10edf4ba818a..2a242fcc2ecf 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -252,11 +252,15 @@ def __init__(self, self.sub_group_size = sub_group_size self.sub_group_to_group_id = {} - see_memory_usage("Before creating fp16 partitions", force=False) - self._create_fp16_partitions_with_defragmentation() + + # Trainable parameters + self.trainable_param_groups = self._get_trainable_parameter_groups() + + see_memory_usage("Before creating fp16 partitions", force=True) + self._create_fp16_partitions_with_defragmentation(self.trainable_param_groups) num_fp16_subgroups = len(self.fp16_partitioned_groups_flat) see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}", - force=False) + force=True) # Optimizer tensor swapping if self.swap_optimizer: @@ -350,19 +354,28 @@ def __init__(self, def destroy(self): self.parameter_offload.destroy() + def _get_trainable_parameter_groups(self): + param_groups = [] + for param_group in self.optimizer.param_groups: + trainable_params = { + "params": [p for p in param_group["params"] if p.requires_grad] + } + param_groups.append(trainable_params) + return param_groups + def _setup_for_real_optimizer(self): - see_memory_usage("Before creating fp32 partitions", force=False) + see_memory_usage("Before creating fp32 partitions", force=True) self._create_fp32_partitions() - see_memory_usage("After creating fp32 partitions", force=False) + see_memory_usage("After creating fp32 partitions", force=True) dist.barrier() # To support pipelined optimizer swapping self._create_next_swappable_fp32_groups() - see_memory_usage("Before initializing optimizer states", force=False) + see_memory_usage("Before initializing optimizer states", force=True) self.initialize_optimizer_states() - see_memory_usage("After initializing optimizer states", force=False) + see_memory_usage("After initializing optimizer states", force=True) dist.barrier() if dist.get_rank() == 0: @@ -523,7 +536,7 @@ def _create_param_groups_fp16_flat_cpu_memory(self): aggregate_params_count = 0 - for j, param_group in enumerate(self.optimizer.param_groups): + for j, param_group in enumerate(self.trainable_param_groups): params_in_group = sum([p.partition_numel() for p in param_group['params']]) flat_buffer_size = params_in_group @@ -552,11 +565,12 @@ def _create_param_groups_fp16_flat_cpu_memory(self): torch.empty(1, dtype=self.dtype)) - def _create_fp16_partitions_with_defragmentation(self): + def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): dist.barrier() + param_groups: List[List[Parameter]] = tuple( self._create_fp16_sub_groups(param_group["params"]) - for param_group in self.optimizer.param_groups) + for param_group in fp16_param_groups) # bookkeeping related to param groups for param_group_idx, param_group in enumerate(param_groups): @@ -884,7 +898,6 @@ def initialize_optimizer_states(self): dtype=gradient_dtype, device=self.device) - timers = self.timers timer_names = set() if self.swap_optimizer: @@ -2122,6 +2135,7 @@ def _get_param_groups(self): def _set_param_groups(self, value): self.optimizer.param_groups = value + self.trainable_param_groups = self._get_trainable_parameter_groups() param_groups = property(_get_param_groups, _set_param_groups) diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index f9715bd1dc8e..b3ce68806444 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -1313,3 +1313,63 @@ def test(self, zero_stage): state = optimizer.optimizer.state[param] step_counts.append(state['step']) assert all(step == step_counts[0] for step in step_counts) + + +class TestZeroFrozenWeights(DistributedTest): + world_size = 1 + + def test(self): + config_dict = { + "train_batch_size": 4, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "fp16": { + "enabled": True + }, + "zero_optimization": { + "stage": 3 + } + } + hidden_dim = 10 + + class MyModel(torch.nn.Module): + def __init__(self, hidden_dim): + super(MyModel, self).__init__() + self.l1 = torch.nn.Linear(hidden_dim, hidden_dim) + self.l2 = torch.nn.Linear(hidden_dim, hidden_dim) + self.act = torch.nn.ReLU() + self.cel = torch.nn.CrossEntropyLoss() + + # freeze one fc + self.l2.weight.requires_grad = False + self.l2.bias.requires_grad = False + + def forward(self, x, y): + x = self.l1(x) + x = self.act(x) + x = self.l2(x) + loss = self.cel(x, y) + val = (x, loss) + return val + + with deepspeed.zero.Init(config_dict_or_path=config_dict): + model = MyModel(hidden_dim) + + model, _, _, _ = deepspeed.initialize(model=model, + model_parameters=model.parameters(), + config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + dist.barrier() + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + loss = loss[1] + model.backward(loss) + model.step()