Skip to content

Commit

Permalink
DeepSpeedCheckpoint: support custom final ln idx (#5506)
Browse files Browse the repository at this point in the history
till today only last layer (idx=-1) was considered using
FINAL_LAYER_NORM_INDEX which is set to -1.
this PR allows the user to pass custom value for model where this
default value does not apply.
see example for usage in HabanaAI/Megatron-DeepSpeed fork repository:

https://github.com/HabanaAI/Megatron-DeepSpeed/blob/c9feb8cacabc6dd4da4266cff08db555a21122e2/tools/verify_checkpoint_non_tp_consistency.py#L296

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
4 people authored May 28, 2024
1 parent 4deb40d commit 2fc702e
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions deepspeed/checkpoint/deepspeed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# DeepSpeed Team

import os
import re
from typing import Dict
import torch

Expand All @@ -21,6 +22,7 @@
ARGS_KEY = 'args'
CHECKPOINT_INFO_KEY = 'checkpoint_info'
ITERATION_KEY = 'iteration'
LAYER_FILE_PREFIX_PATTERN = r'layer_(\d+)-model_.*'

SEQUENTIAL_LAYERS = [
'input_layernorm.weight', 'input_layernorm.bias', 'self_attention.dense.bias', 'post_attention_layernorm.weight',
Expand All @@ -32,7 +34,13 @@

class DeepSpeedCheckpoint(object):

def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
def __init__(self,
dir,
tp_degree=None,
pp_degree=None,
dp_degree=None,
final_layer_norm_idx=FINAL_LAYER_NORM_INDEX):
self.final_layer_norm_idx = final_layer_norm_idx
self.dir = dir

pipeline_parallel = len(get_files_with_prefix(get_files(dir), LAYER_FILE_PREFIX)) > 0
Expand Down Expand Up @@ -73,7 +81,7 @@ def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
self.pp_to_transformer_map = self._build_pp_transformer_map()
self.transformer_file_map = self._build_transformer_file_map()
self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX)
self.tp_to_final_norm_map = self._build_tp_other_layer_map(self.final_layer_norm_idx)
self._build_global_state()

def is_change_tp_degree(self):
Expand Down Expand Up @@ -125,7 +133,7 @@ def get_embedding_layer_id(self):
return self.layer_keys[EMBEDDING_LAYER_INDEX]

def get_final_norm_layer_id(self):
return self.layer_keys[FINAL_LAYER_NORM_INDEX]
return self.layer_keys[self.final_layer_norm_idx]

def get_iteration(self):
if not ITERATION_KEY in self.global_state:
Expand Down Expand Up @@ -214,7 +222,7 @@ def get_2d_parallel_files(self, tp_index: int, pp_index: int) -> list:
def _build_pp_transformer_map(self):
data_map = {}
if self.pp_degree > 0:
transformer_layers = self.layer_keys[1:-1]
transformer_layers = self.layer_keys[1:self.final_layer_norm_idx]
layers_per_pp = len(transformer_layers) // self.pp_degree
data_map = {
i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp]
Expand All @@ -229,7 +237,7 @@ def _dump_mapping(self, data_map, map_tag=None):
print(f'{k} = {v}')

def _build_transformer_file_map(self):
transformer_layer_keys = self.layer_keys[1:-1]
transformer_layer_keys = self.layer_keys[1:self.final_layer_norm_idx]
file_map = {}
# XXX: this is not guaranteed
layers_per_pp = 1
Expand All @@ -238,7 +246,7 @@ def _build_transformer_file_map(self):
#print(f"{transformer_layer_keys} {layers_per_pp}")
for key_index, layer_key in enumerate(transformer_layer_keys):
pp_index = key_index // layers_per_pp
layer_files = get_files_with_prefix(self.layer_files, layer_key)
layer_files = get_files_with_prefix(self.layer_files, layer_key + '-')
layer_file_partitions = partition_data(layer_files, self.tp_degree)
for tp_index in range(self.tp_degree):
map_key = (tp_index, pp_index)
Expand All @@ -263,11 +271,13 @@ def validate_files(self):

def _get_layer_keys(self):
key_set = set()
key_len = len(LAYER_FILE_PREFIX) + 2
for file_path in self.layer_files:
_, fname = os.path.split(file_path)
key_set.add(fname[:key_len])
return sorted(list(key_set))
layer_id = re.search(LAYER_FILE_PREFIX_PATTERN, fname).group(1)
key_set.add(layer_id)
sorted_ids = sorted(list(key_set), key=int)
layer_keys = [LAYER_FILE_PREFIX + str(layer_id) for layer_id in sorted_ids]
return layer_keys

def _merge_state_dicts(self, sd_list):
merged_sd = {}
Expand Down

0 comments on commit 2fc702e

Please sign in to comment.