
Commit

ZeRO3 handling frozen weights (#2653)
tjruwase authored Jan 17, 2023
1 parent 35575bc commit bf6b980
Showing 2 changed files with 85 additions and 11 deletions.
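
For orientation, below is a minimal, self-contained sketch of the filtering pattern this commit introduces: parameters with `requires_grad=False` are dropped from each optimizer param group before ZeRO-3 builds its flat fp16 partitions. The standalone `get_trainable_parameter_groups` helper is hypothetical; the actual change is the `_get_trainable_parameter_groups` method added to stage3.py in the diff that follows.

```python
import torch

# Hypothetical standalone version of the filtering this commit adds:
# keep only trainable parameters from each optimizer param group.
def get_trainable_parameter_groups(optimizer):
    return [
        {"params": [p for p in group["params"] if p.requires_grad]}
        for group in optimizer.param_groups
    ]

model = torch.nn.Sequential(
    torch.nn.Linear(10, 10),
    torch.nn.Linear(10, 10),
)
# Freeze the second layer, mirroring the new unit test below.
for p in model[1].parameters():
    p.requires_grad_(False)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
trainable = get_trainable_parameter_groups(optimizer)
# Only the first layer's weight and bias remain trainable.
assert len(trainable[0]["params"]) == 2
```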
36 changes: 25 additions & 11 deletions deepspeed/runtime/zero/stage3.py
@@ -252,11 +252,15 @@ def __init__(self,
self.sub_group_size = sub_group_size

self.sub_group_to_group_id = {}
see_memory_usage("Before creating fp16 partitions", force=False)
self._create_fp16_partitions_with_defragmentation()

# Trainable parameters
self.trainable_param_groups = self._get_trainable_parameter_groups()

see_memory_usage("Before creating fp16 partitions", force=True)
self._create_fp16_partitions_with_defragmentation(self.trainable_param_groups)
num_fp16_subgroups = len(self.fp16_partitioned_groups_flat)
see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}",
force=False)
force=True)

# Optimizer tensor swapping
if self.swap_optimizer:
@@ -350,19 +354,28 @@ def __init__(self,
def destroy(self):
self.parameter_offload.destroy()

def _get_trainable_parameter_groups(self):
param_groups = []
for param_group in self.optimizer.param_groups:
trainable_params = {
"params": [p for p in param_group["params"] if p.requires_grad]
}
param_groups.append(trainable_params)
return param_groups

def _setup_for_real_optimizer(self):
see_memory_usage("Before creating fp32 partitions", force=False)
see_memory_usage("Before creating fp32 partitions", force=True)
self._create_fp32_partitions()
see_memory_usage("After creating fp32 partitions", force=False)
see_memory_usage("After creating fp32 partitions", force=True)
dist.barrier()

# To support pipelined optimizer swapping
self._create_next_swappable_fp32_groups()

see_memory_usage("Before initializing optimizer states", force=False)
see_memory_usage("Before initializing optimizer states", force=True)

self.initialize_optimizer_states()
see_memory_usage("After initializing optimizer states", force=False)
see_memory_usage("After initializing optimizer states", force=True)
dist.barrier()

if dist.get_rank() == 0:
@@ -523,7 +536,7 @@ def _create_param_groups_fp16_flat_cpu_memory(self):

aggregate_params_count = 0

for j, param_group in enumerate(self.optimizer.param_groups):
for j, param_group in enumerate(self.trainable_param_groups):
params_in_group = sum([p.partition_numel() for p in param_group['params']])

flat_buffer_size = params_in_group
@@ -552,11 +565,12 @@ def _create_param_groups_fp16_flat_cpu_memory(self):
torch.empty(1,
dtype=self.dtype))

def _create_fp16_partitions_with_defragmentation(self):
def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
dist.barrier()

param_groups: List[List[Parameter]] = tuple(
self._create_fp16_sub_groups(param_group["params"])
for param_group in self.optimizer.param_groups)
for param_group in fp16_param_groups)

# bookkeeping related to param groups
for param_group_idx, param_group in enumerate(param_groups):
@@ -884,7 +898,6 @@ def initialize_optimizer_states(self):
dtype=gradient_dtype,
device=self.device)

timers = self.timers
timer_names = set()

if self.swap_optimizer:
@@ -2122,6 +2135,7 @@ def _get_param_groups(self):

def _set_param_groups(self, value):
self.optimizer.param_groups = value
self.trainable_param_groups = self._get_trainable_parameter_groups()

param_groups = property(_get_param_groups, _set_param_groups)

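
The `param_groups` property at the end of the stage3.py hunk above means that reassigning param groups through the ZeRO-3 optimizer also refreshes the cached trainable groups. A rough standalone illustration of that pattern, using simplified names rather than the actual DeepSpeed classes:

```python
import torch

class ZeroLikeWrapper:
    """Simplified sketch of the property pattern used in stage3.py (not the real class)."""

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.trainable_param_groups = self._get_trainable_parameter_groups()

    def _get_trainable_parameter_groups(self):
        return [
            {"params": [p for p in g["params"] if p.requires_grad]}
            for g in self.optimizer.param_groups
        ]

    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        # Assigning param_groups keeps the trainable view in sync.
        self.optimizer.param_groups = value
        self.trainable_param_groups = self._get_trainable_parameter_groups()

    param_groups = property(_get_param_groups, _set_param_groups)

linear = torch.nn.Linear(4, 4)
wrapper = ZeroLikeWrapper(torch.optim.Adam(linear.parameters()))
linear.weight.requires_grad_(False)
wrapper.param_groups = wrapper.optimizer.param_groups  # reassignment refreshes the cache
assert len(wrapper.trainable_param_groups[0]["params"]) == 1  # only the bias is left
```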
60 changes: 60 additions & 0 deletions tests/unit/runtime/zero/test_zero.py
@@ -1313,3 +1313,63 @@ def test(self, zero_stage):
state = optimizer.optimizer.state[param]
step_counts.append(state['step'])
assert all(step == step_counts[0] for step in step_counts)


class TestZeroFrozenWeights(DistributedTest):
world_size = 1

def test(self):
config_dict = {
"train_batch_size": 4,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-4
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": 3
}
}
hidden_dim = 10

class MyModel(torch.nn.Module):
def __init__(self, hidden_dim):
super(MyModel, self).__init__()
self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)
self.l2 = torch.nn.Linear(hidden_dim, hidden_dim)
self.act = torch.nn.ReLU()
self.cel = torch.nn.CrossEntropyLoss()

# freeze one fc
self.l2.weight.requires_grad = False
self.l2.bias.requires_grad = False

def forward(self, x, y):
x = self.l1(x)
x = self.act(x)
x = self.l2(x)
loss = self.cel(x, y)
val = (x, loss)
return val

with deepspeed.zero.Init(config_dict_or_path=config_dict):
model = MyModel(hidden_dim)

model, _, _, _ = deepspeed.initialize(model=model,
model_parameters=model.parameters(),
config=config_dict)
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device)
dist.barrier()
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
loss = loss[1]
model.backward(loss)
model.step()
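
As a possible follow-up to the test above (not part of this commit), one could also assert that the frozen layer's weights are unchanged after training. Under ZeRO stage 3 the parameters are partitioned across ranks, so they would need to be gathered first, e.g. with `deepspeed.zero.GatheredParameters`; the helper below is a hypothetical sketch.

```python
import deepspeed

def snapshot_params(module):
    # Gather the (possibly partitioned) ZeRO-3 parameters and clone them for later comparison.
    with deepspeed.zero.GatheredParameters(list(module.parameters()), modifier_rank=None):
        return [p.detach().clone() for p in module.parameters()]

# Possible usage around the training loop in the test:
#   before = snapshot_params(model.module.l2)
#   ... run the loop ...
#   after = snapshot_params(model.module.l2)
#   assert all(torch.equal(a, b) for a, b in zip(before, after))
```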
