base_optimizer.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch

from deepspeed.utils import logger
from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank


class DeepSpeedOptimizer(object):
    pass


class ZeROOptimizer(DeepSpeedOptimizer):

def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
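        """Load per-parameter high-precision (hp) optimizer state from a ZeRO checkpoint directory.

        `lp_groups_name` names the attribute on `self` that holds the
        low-precision (lp) parameter groups whose hp state is being restored.
        """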
checkpoint_dir = os.path.join(checkpoint_dir, "zero")
optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
assert os.path.isfile(
optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
optim_sd = torch.load(optim_state_path, weights_only=False)
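
        # Restore optimizer-wide state from the checkpoint (state not tied to a
        # single parameter) before rebuilding per-parameter entries.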
self._load_global_state(optim_sd)
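
        # Resolve this rank's tensor-parallel coordinates; the bwc_* helper keeps
        # backwards compatibility across older and newer mpu APIs.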
tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
if self.mpu is None:
            logger.warning("MPU is not provided, setting tp size to 1 in checkpoint loading.")
tp_world_size = 1
else:
tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
else self.mpu.get_tensor_model_parallel_world_size()
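
        # Pair each runtime param group with its checkpointed counterpart and
        # rebuild the per-parameter optimizer state.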
for i, (param_group,
loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
# We have an assumption that all params in the same param_group have the same keys
opt_keys = set()
steps = []
lp_groups = getattr(self, lp_groups_name)
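            # Each lp param that owns a fragment of the flat hp buffer loads its
            # slice of the checkpointed hp optimizer state.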
for lp in lp_groups[i]:
if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)
for key in lp._hp_mapping.get_optim_state_keys():
opt_keys.add(key)
steps.append(step)
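
            # All fragments in the group must report the same optimizer step.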
hp_param = param_group['params'][0]
assert all(step == steps[0] for step in steps), f"Steps {steps} are not equal"
if steps[0] is not None:
self.optimizer.state[hp_param]['step'] = steps[0]
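
            # Stitch the per-fragment states back into flat optimizer state
            # keyed by the group's single flat hp parameter.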
map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)
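
            # Restore group hyperparameters (lr, betas, etc.); 'params' stays
            # as the runtime's own tensors.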
for key, value in loaded_param_group.items():
if key == 'params':
continue
param_group[key] = value
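
# Usage sketch (illustrative, assuming a concrete ZeRO optimizer subclass that
# defines `self.mpu`, `self.optimizer`, `self.param_names`, and a low-precision
# groups attribute such as `bf16_groups`):
#
#   optimizer.load_hp_checkpoint_state_from_checkpoint_dir("bf16_groups", checkpoint_dir)
#
# The method itself appends the "zero/" subdirectory to `checkpoint_dir`.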