Skip to content

Commit

Permalink
backwards compatibility w. v020 ckpts, fix issue with zero-1 ckpts (m…
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffra authored Nov 19, 2020
1 parent 9de21b7 commit dce054d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 10 deletions.
2 changes: 2 additions & 0 deletions deepspeed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def _parse_version(version_str):
sys.modules['deepspeed.pt.deepspeed_utils'] = deepspeed.runtime.utils
setattr(deepspeed.pt, 'deepspeed_config', deepspeed.runtime.config)
sys.modules['deepspeed.pt.deepspeed_config'] = deepspeed.runtime.config
setattr(deepspeed.pt, 'loss_scaler', deepspeed.runtime.fp16.loss_scaler)
sys.modules['deepspeed.pt.loss_scaler'] = deepspeed.runtime.fp16.loss_scaler


def initialize(args,
Expand Down
12 changes: 2 additions & 10 deletions deepspeed/runtime/zero/stage1.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ def __init__(self,

# max elems per param group
self.max_elems_per_comm = []
self.legacy_max_elements_per_comm = max_elements_per_comm

# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
Expand Down Expand Up @@ -859,7 +858,6 @@ def _elastic_state_dict(self):
state_dict['zero_stage'] = ZERO_OPTIMIZATION_OPTIMIZER_STATES
state_dict['partition_count'] = self.partition_count
state_dict['num_comm_intervals_per_group'] = self.num_comm_intervals_per_group
state_dict['max_elems_per_comm'] = self.max_elems_per_comm

# Remove paddings for DP alignment to enable loading for other alignment values
fp32_groups_without_padding = self._get_groups_without_padding(
Expand Down Expand Up @@ -933,10 +931,7 @@ def _restore_from_fp32_weights(self, all_state_dict):
sd['local_sub_partitions_of_fp32_groups'][group_idx]
for sd in all_state_dict
]
if 'max_elems_per_comm' in all_state_dict[0]:
max_elems_per_comm = all_state_dict[0]['max_elems_per_comm'][group_idx]
else:
max_elems_per_comm = self.legacy_max_elements_per_comm
max_elems_per_comm = self.max_elems_per_comm[group_idx]

sub_partition_weights = self._retrieve_group_sub_partition_weights(
all_partition_fp32_weights,
Expand Down Expand Up @@ -1009,10 +1004,7 @@ def _restore_base_optimizer_state(self, state_dict_list):
all_partition_group_states = [
sd['base_optimizer_state'][group_idx] for sd in state_dict_list
]
if 'max_elems_per_comm' in state_dict_list[0]:
max_elems_per_comm = state_dict_list[0]['max_elems_per_comm'][group_idx]
else:
max_elems_per_comm = self.legacy_max_elements_per_comm
max_elems_per_comm = self.max_elems_per_comm[group_idx]
group_optimizer_states = self._retrieve_group_optimizer_states(
all_partition_group_states,
max_elems_per_comm)
Expand Down

0 comments on commit dce054d

Please sign in to comment.