diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index 7a18f98a49e8..50049a2a1996 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -238,7 +238,8 @@ def init_inference(model,
                    moe_experts=1,
                    moe_type='standard',
                    args=None,
-                   enable_cuda_graph=False):
+                   enable_cuda_graph=False,
+                   save_mp_checkpoint_path=None):
     """Initialize the DeepSpeed InferenceEngine.
 
     Arguments:
@@ -304,6 +305,7 @@ def init_inference(model,
                              moe_experts,
                              moe_type,
                              args,
-                             enable_cuda_graph)
+                             enable_cuda_graph,
+                             save_mp_checkpoint_path)
 
     return engine
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 8af9a12c2809..db9efb19dcb1 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -14,7 +14,6 @@
 from ..runtime.state_dict_factory import SDLoaderFactory
 from ..runtime.weight_quantizer import WeightQuantization
 from ..module_inject.replace_module import replace_transformer_layer
-from ..utils import logger
 from ..comm.comm import init_distributed
 from ..pipe import PipelineModule
 from ..moe.utils import has_moe_layers
@@ -50,7 +49,8 @@ def __init__(self,
                  moe_experts=1,
                  moe_type='standard',
                  config=None,
-                 enable_cuda_graph=False):
+                 enable_cuda_graph=False,
+                 save_mp_checkpoint_path=None):
         """
         Args:
             model: torch.nn.Module
@@ -130,7 +130,8 @@ def __init__(self,
                     moe_experts,
                     moe_type,
                     training_mp_size,
-                    self.checkpoint if replace_with_kernel_inject else None)
+                    self.checkpoint if replace_with_kernel_inject else None,
+                    save_mp_checkpoint_path=save_mp_checkpoint_path)
         elif replace_method == 'auto':
             self._apply_injection_policy(
                 return_tuple=return_tuple,
@@ -139,12 +140,17 @@ def __init__(self,
                 moe_experts=moe_experts,
                 moe_type=moe_type,
                 training_mp_size=training_mp_size,
-                checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None)
+                checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None,
+                save_mp_checkpoint_path=save_mp_checkpoint_path)
 
         device = torch.cuda.current_device()
-        logger.info(f"Place model to device: {device}")
         self.module.to(device)
 
+        if self.mp_world_size > 1:
+            _rng_state = torch.cuda.get_rng_state().to(torch.cuda.current_device())
+            dist.broadcast(_rng_state, 0)
+            torch.cuda.set_rng_state(_rng_state.cpu())
+
         if self.mp_world_size > 1:
             self.model_orig_fwd = self.module.forward
             self.module.forward = self.forward
@@ -226,9 +232,9 @@ def _validate_args(self, mpu):
         for method in methods:
             if not hasattr(mpu, method):
                 raise ValueError(f"mpu is missing {method}")
-        if self.checkpoint is not None and not isinstance(self.checkpoint, str):
+        if self.checkpoint is not None and not isinstance(self.checkpoint, (str, dict)):
             raise ValueError(
-                f"checkpoint must be None or a str, got {type(self.checkpoint)}")
+                f"checkpoint must be None, str or dict, got {type(self.checkpoint)}")
 
         supported_dtypes = [None, torch.half, torch.int8, torch.float]
         if self.dtype not in supported_dtypes:
@@ -315,32 +321,37 @@ def _apply_injection_policy(self,
                                 moe_experts=1,
                                 moe_type='standard',
                                 training_mp_size=1,
-                                checkpoint_dir=None):
+                                checkpoint_dir=None,
+                                save_mp_checkpoint_path=False):
         checkpoint = SDLoaderFactory.get_sd_loader_json(
-            checkpoint_dir) if checkpoint_dir is not None else None
-        replace_transformer_layer(client_module,
-                                  self.module,
-                                  triangular_masking=self.triangular_masking,
-                                  policy=injection_policy,
-                                  mp_size=self.mp_world_size,
-                                  mp_group=self.mp_group,
-                                  ep_group=self.ep_group,
-                                  expert_mp_group=self.expert_mp_group,
-                                  config=self.config,
-                                  fp16=(self.dtype == torch.half),
-                                  training=False,
-                                  return_tuple=return_tuple,
-                                  quantize=(self.dtype == torch.int8),
-                                  quantize_settings=(self.quantization_scales,
-                                                     self.quantize_merge_count,
-                                                     self.mlp_extra_grouping,
-                                                     self.quantize_groups),
-                                  replace_with_kernel_inject=replace_with_kernel_inject,
-                                  moe=moe,
-                                  moe_experts=moe_experts,
-                                  moe_type=moe_type,
-                                  training_mp_size=training_mp_size,
-                                  checkpoint=checkpoint)
+            checkpoint_dir,
+            self.checkpoint_engine) if checkpoint_dir is not None else None
+        replace_transformer_layer(
+            client_module,
+            self.module,
+            triangular_masking=self.triangular_masking,
+            policy=injection_policy,
+            mp_size=self.mp_world_size,
+            mp_group=self.mp_group,
+            ep_group=self.ep_group,
+            expert_mp_group=self.expert_mp_group,
+            config=self.config,
+            fp16=(self.dtype == torch.half),
+            training=False,
+            return_tuple=return_tuple,
+            quantize=(self.dtype == torch.int8),
+            quantize_settings=(self.quantization_scales,
+                               self.quantize_merge_count,
+                               self.mlp_extra_grouping,
+                               self.quantize_groups),
+            replace_with_kernel_inject=replace_with_kernel_inject,
+            moe=moe,
+            moe_experts=moe_experts,
+            moe_type=moe_type,
+            training_mp_size=training_mp_size,
+            checkpoint_dict=checkpoint,
+            save_mp_checkpoint_path=save_mp_checkpoint_path,
+        )
 
     def _get_all_ckpt_names(self, checkpoints_path, tag):
         ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
@@ -380,8 +391,7 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
             ckpt_list = self._get_all_ckpt_names(load_dir, tag)
             sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
         else:
-            sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir,
-                                                           self.checkpoint_engine)
+            sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)
 
         if type(sd_loader) is list:
             self.sd = torch.load(sd_loader[0], map_location='cpu')
diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py
index 1d05b6e853f6..e0f44675dfd7 100644
--- a/deepspeed/module_inject/load_checkpoint.py
+++ b/deepspeed/module_inject/load_checkpoint.py
@@ -2,9 +2,10 @@
 import deepspeed.ops.transformer as transformer_inference
 from ..runtime.zero import GatheredParameters
 from .layers import LinearLayer, Normalize, EmbeddingLayer
+import torch
 
 
-def load_model_with_checkpoint(r_module, sd, mp_replace):
+def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0):
     error_msgs = []
 
     def transpose(data):
@@ -29,33 +30,76 @@ def load(module, prefix):
             module.bias = mp_replace.copy(module.bias.data, sd[prefix + 'bias'])
 
     def load_transformer_layer(module, prefix):
-        module.norm_w.data.copy_(sd[prefix + 'input_layernorm.' + 'weight'])
-        module.norm_b.data.copy_(sd[prefix + 'input_layernorm.' + 'bias'])
-        module.attention.attn_qkvw = mp_replace.copy(
-            module.attention.attn_qkvw.data,
-            transpose(sd[prefix + 'self_attention.query_key_value.' + 'weight']))
-        module.attention.attn_qkvb = mp_replace.copy(
-            module.attention.attn_qkvb.data,
-            sd[prefix + 'self_attention.query_key_value.' + 'bias'])
-        module.attention.attn_ow = mp_replace.copy(
-            module.attention.attn_ow.data,
-            transpose(sd[prefix + 'self_attention.dense.' + 'weight']))
-        module.attention.attn_ob = mp_replace.copy(
-            module.attention.attn_ob.data,
-            sd[prefix + 'self_attention.dense.' + 'bias'])
-        module.mlp.attn_nw.data.copy_(sd[prefix + 'post_attention_layernorm.' +
-                                         'weight'])
-        module.mlp.attn_nb.data.copy_(sd[prefix + 'post_attention_layernorm.' + 'bias'])
-        module.mlp.inter_w = mp_replace.copy(
-            module.mlp.inter_w.data,
-            transpose(sd[prefix + 'mlp.dense_h_to_4h.' + 'weight']))
-        module.mlp.inter_b = mp_replace.copy(module.mlp.inter_b.data,
-                                             sd[prefix + 'mlp.dense_h_to_4h.' + 'bias'])
-        module.mlp.output_w = mp_replace.copy(
-            module.mlp.output_w.data,
-            transpose(sd[prefix + 'mlp.dense_4h_to_h.' + 'weight']))
-        module.mlp.output_b = mp_replace.copy(module.mlp.output_b.data,
-                                              sd[prefix + 'mlp.dense_4h_to_h.' + 'bias'])
+        if ckpt_type == "tp":
+
+            def load_parameters(module, prefix):
+                for n, p in module.named_parameters():
+                    if len(n.split('.')) == 1:
+                        src_shape = sd[prefix + n].shape
+                        dst_shape = p.shape
+
+                        if (len(src_shape) == 2 and len(dst_shape) == 2):
+                            if src_shape[0] == dst_shape[0] and src_shape[
+                                    1] == dst_shape[1]:
+                                p.data.copy_(sd[prefix + n])
+                            else:
+                                if src_shape[0] != dst_shape[0]:
+                                    weight_split = torch.split(
+                                        sd[prefix + n],
+                                        dst_shape[0],
+                                        dim=0)[rank].to(
+                                            torch.cuda.current_device()).contiguous()
+                                else:
+                                    weight_split = torch.split(
+                                        sd[prefix + n],
+                                        dst_shape[1],
+                                        dim=1)[rank].to(
+                                            torch.cuda.current_device()).contiguous()
+                                p.data.copy_(weight_split.contiguous())
+                        else:
+                            if src_shape[0] == dst_shape[0]:
+                                p.data.copy_(sd[prefix + n])
+                            else:
+                                bias_split = torch.split(
+                                    sd[prefix + n],
+                                    dst_shape[-1])[rank].to(
+                                        torch.cuda.current_device()).contiguous()
+                                p.data.copy_(bias_split)
+
+            load_parameters(module, prefix)
+            for n, child in module.named_children():
+                load_parameters(child, prefix + n + '.')
+        else:
+            module.norm_w.data.copy_(sd[prefix + 'input_layernorm.' + 'weight'])
+            module.norm_b.data.copy_(sd[prefix + 'input_layernorm.' + 'bias'])
+            module.attention.attn_qkvw = mp_replace.copy(
+                module.attention.attn_qkvw.data,
+                transpose(sd[prefix + 'self_attention.query_key_value.' + 'weight']))
+            module.attention.attn_qkvb = mp_replace.copy(
+                module.attention.attn_qkvb.data,
+                sd[prefix + 'self_attention.query_key_value.' + 'bias'])
+            module.attention.attn_ow = mp_replace.copy(
+                module.attention.attn_ow.data,
+                transpose(sd[prefix + 'self_attention.dense.' + 'weight']))
+            module.attention.attn_ob = mp_replace.copy(
+                module.attention.attn_ob.data,
+                sd[prefix + 'self_attention.dense.' + 'bias'])
+            module.mlp.attn_nw.data.copy_(sd[prefix + 'post_attention_layernorm.' +
+                                             'weight'])
+            module.mlp.attn_nb.data.copy_(sd[prefix + 'post_attention_layernorm.' +
+                                              'bias'])
+            module.mlp.inter_w = mp_replace.copy(
+                module.mlp.inter_w.data,
+                transpose(sd[prefix + 'mlp.dense_h_to_4h.' + 'weight']))
+            module.mlp.inter_b = mp_replace.copy(
+                module.mlp.inter_b.data,
+                sd[prefix + 'mlp.dense_h_to_4h.' + 'bias'])
+            module.mlp.output_w = mp_replace.copy(
+                module.mlp.output_w.data,
+                transpose(sd[prefix + 'mlp.dense_4h_to_h.' + 'weight']))
+            module.mlp.output_b = mp_replace.copy(
+                module.mlp.output_b.data,
+                sd[prefix + 'mlp.dense_4h_to_h.' + 'bias'])
 
     layer_policies = {
         nn.Linear: load,
@@ -95,6 +139,9 @@ def load_module_recursive(module, prefix='', level=0):
                                  dtype=child.weight.dtype,
                                  eps=child.eps)
                 setattr(module, name, child)
+            elif child.__class__ is nn.Linear:
+                child = LinearLayer(weight=child.weight, bias=child.bias)
+                setattr(module, name, child)
             else:
                 ds_id = None
                 if hasattr(child.weight, 'ds_id'):
@@ -107,9 +154,10 @@ def load_module_recursive(module, prefix='', level=0):
                     layer_policies[child.__class__](child, prefix + name + '.')
             else:
-                load_module_recursive(child,
-                                      prefix if level == 0 else prefix + name + '.',
-                                      level + 1)
+                load_module_recursive(
+                    child,
+                    prefix if level == 0 and ckpt_type == 'pp' else prefix + name + '.',
+                    level + 1)
 
     load_module_recursive(r_module)
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 151abfacf004..b9e9d90c1778 100755
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -1,3 +1,4 @@
+import os
 import torch
 import tqdm
 import deepspeed
@@ -11,6 +12,7 @@
 from ..runtime.zero import GatheredParameters
 from .layers import LinearAllreduce, LinearLayer
 from .load_checkpoint import load_model_with_checkpoint
+import time
 
 
 class ReplaceWithTensorSlicing:
@@ -117,6 +119,21 @@ def copy(self, dst, src):
 
         return torch.nn.parameter.Parameter(dst, requires_grad=False)
 
+
+def get_transformer_name(replaced_module):
+    from .replace_policy import supported_models
+    from torch.nn import ModuleList
+    transformer_name = ''
+    for n, c in replaced_module.named_children():
+        if c.__class__ in supported_models:
+            transformer_name += n + '.'
+            for name, child in c.named_children():
+                if child.__class__ is ModuleList:
+                    transformer_name += name
+                    break
+            break
+    return transformer_name
+
 
 def replace_transformer_layer(orig_layer_impl,
                               model,
                               policy=None,
@@ -144,7 +161,8 @@ def replace_transformer_layer(orig_layer_impl,
                               moe=False,
                               moe_experts=1,
                               moe_type='standard',
-                              checkpoint=None):
+                              checkpoint_dict=None,
+                              save_mp_checkpoint_path=None):
     """ Replace bert-style transformer layers with DeepSpeed's transformer layer
     Arguments:
         orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
@@ -764,14 +782,86 @@ def replace_fn(child, _policy, layer_id=0):
                                     replace_fn=replace_fn,
                                     _replace_policy=policy)
 
-    if checkpoint is not None:
-        pbar = tqdm.tqdm(total=len(checkpoint),
-                         desc=f"Loading {len(checkpoint)} checkpoint shards")
-        for i in range(len(checkpoint)):
-            if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
-                pbar.update(1)
-            sd = torch.load(checkpoint[i], map_location='cpu')
-            load_model_with_checkpoint(replaced_module, sd, mp_replace)
+    if checkpoint_dict is not None:
+        start_time = time.time()
+        rank = dist.get_rank() if dist.is_initialized() else 0
+        world_size = dist.get_world_size() if dist.is_initialized() else 1
+        checkpoint = checkpoint_dict['checkpoints']
+        ckpt_type = checkpoint_dict.get('parallelization', 'pp')
+        ckpt_mp_size = checkpoint_dict.get('mp_size', mp_size)
+        base_dir = checkpoint_dict.get('base_dir', '')
+
+        if ckpt_type == 'pp':
+            pbar = tqdm.tqdm(total=len(checkpoint),
+                             desc=f"Loading {len(checkpoint)} checkpoint shards")
+            for i in range(len(checkpoint)):
+                if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
+                    pbar.update(1)
+                sd = torch.load(checkpoint[i], map_location='cpu')
+                load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type)
+        else:
+            num_checkpoints = len(checkpoint) // ckpt_mp_size
+            assert world_size >= ckpt_mp_size,\
+                "Currently, merging checkpoints is not supported (when world_size is smaller than #checkpoints)!"
+            checkpoint_stride = world_size // ckpt_mp_size
+            pbar = tqdm.tqdm(total=num_checkpoints,
+                             desc=f"Loading {num_checkpoints} checkpoint shards")
+            for i in range(num_checkpoints):
+                if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
+                    pbar.update(1)
+
+                ckpt_index = i * ckpt_mp_size + (rank // checkpoint_stride)
+                ckpt_file = os.path.join(
+                    base_dir,
+                    checkpoint[ckpt_index]) if base_dir else checkpoint[ckpt_index]
+                sd = torch.load(ckpt_file, map_location='cpu')
+                load_model_with_checkpoint(replaced_module,
+                                           sd,
+                                           mp_replace,
+                                           ckpt_type,
+                                           rank % (world_size // ckpt_mp_size))
+        print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")
+
+    if save_mp_checkpoint_path is not None:
+        from collections import OrderedDict
+        import json
+
+        ckpt_name = checkpoint_dict['type']
+        if dist.is_initialized():
+            dist.barrier()
+        transformer_name = get_transformer_name(replaced_module)
+        non_tp_ckpt_name = f'{ckpt_name}-non-tp.pt'
+        ckpt_files = [non_tp_ckpt_name] * world_size
+        if not dist.is_initialized() or dist.get_rank() == 0:
+            print("Saving tp-sharded checkpoints")
+            torch.save(
+                OrderedDict({
+                    k: v
+                    for k,
+                    v in dict(replaced_module.state_dict()).items()
+                    if transformer_name not in k
+                }),
+                f'{save_mp_checkpoint_path}/{non_tp_ckpt_name}')
+            ckpt_files += [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)]
+            config = json.dumps({
+                'type': ckpt_name,
+                'base_dir': f'{save_mp_checkpoint_path}',
+                'checkpoints': ckpt_files,
+                'version': 1.0,
+                'parallelization': 'tp',
+                'mp_size': world_size
+            })
+            with open(f"{save_mp_checkpoint_path}/{ckpt_name}_ds-inference_config.json",
+                      "w") as cfg:
+                cfg.write(config)
+        torch.save(
+            OrderedDict({
+                k: v
+                for k,
+                v in dict(replaced_module.state_dict()).items() if transformer_name in k
+            }),
+            f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{rank:0>2d}.pt')
+
     return replaced_module
diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py
index eeb6d613969b..3d5c53275e33 100755
--- a/deepspeed/module_inject/replace_policy.py
+++ b/deepspeed/module_inject/replace_policy.py
@@ -4,6 +4,8 @@
 from torch.nn.parameter import Parameter
 from packaging import version as pkg_version
 
+supported_models = {None}
+
 
 class DSPolicy(ABC):
     def __init__(self,
@@ -329,6 +331,9 @@ def __init__(self, client_module, inference=True):
         try:
             import transformers
             BLOOMLayerPolicy._orig_layer_class = transformers.models.bloom.modeling_bloom.BloomBlock
+            global supported_models
+            supported_models.update(
+                {transformers.models.bloom.modeling_bloom.BloomModel})
         except:
             BLOOMLayerPolicy._orig_layer_class = None
diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py
index 81762c10b014..df65fb317e9b 100755
--- a/deepspeed/ops/transformer/inference/transformer_inference.py
+++ b/deepspeed/ops/transformer/inference/transformer_inference.py
@@ -191,6 +191,7 @@ def split_tensor_along_last_dim(tensor,
         return tensor_list
 
     def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor):
+        alibi = alibi.to(torch.cuda.current_device())
         head_dim = hidden_size_per_partition // num_attention_heads_per_partition
         new_tensor_shape = mixed_x_layer.size()[:-1] + (
             num_attention_heads_per_partition,
diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py
index 6097e8baa004..0b720ff471f3 100755
--- a/deepspeed/runtime/state_dict_factory.py
+++ b/deepspeed/runtime/state_dict_factory.py
@@ -20,17 +20,23 @@ class SDLoaderFactory:
     @staticmethod
     def get_sd_loader_json(json_file, checkpoint_engine):
-        with open(json_file) as f:
-            data = json.load(f)
-            sd_type = data['type']
-            ckpt_list = data['checkpoints']
-            version = data['version']
-            if 'BLOOM' in sd_type or 'Bloom' in sd_type:
-                return ckpt_list
-            return SDLoaderFactory.get_sd_loader(ckpt_list,
-                                                 checkpoint_engine,
-                                                 sd_type,
-                                                 version)
+        if isinstance(json_file, str):
+            with open(json_file) as f:
+                data = json.load(f)
+        else:
+            assert isinstance(json_file, dict)
+            data = json_file
+        sd_type = data['type']
+        ckpt_list = data['checkpoints']
+        version = data['version']
+        ckpt_type = data.get('parallelization', 'pp')
+        mp_size = data.get('mp_size', 0)
+        if 'bloom' in sd_type.lower():
+            return data
+        return SDLoaderFactory.get_sd_loader(ckpt_list,
+                                             checkpoint_engine,
+                                             sd_type,
+                                             version)
 
     @staticmethod
     def get_sd_loader(ckpt_list, checkpoint_engine, sd_type='Megatron', version=None):
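
Usage sketch (not part of the patch; `model`, `world_size` and the file names below are illustrative placeholders): with this change, a first init_inference call can pre-shard a pipeline-style ('pp') checkpoint into tensor-parallel ('tp') shards via save_mp_checkpoint_path, and later runs can pass the generated ds-inference config as `checkpoint` so each rank loads only its own shard.

    import torch
    import deepspeed

    # Run 1: load the original checkpoint description and write tp shards under
    # ./mp_sharded. Rank 0 saves <type>-non-tp.pt plus
    # <type>_ds-inference_config.json; every rank saves its own <type>-tp_XX.pt,
    # where <type> is the 'type' field of the checkpoint json.
    engine = deepspeed.init_inference(model,
                                      mp_size=world_size,
                                      dtype=torch.half,
                                      checkpoint="checkpoints.json",
                                      replace_with_kernel_inject=True,
                                      save_mp_checkpoint_path="./mp_sharded")

    # Run 2 and later: point `checkpoint` at the generated config. Its
    # 'parallelization' field is 'tp' and its 'checkpoints' list names the
    # pre-split files, so load_model_with_checkpoint copies only the slice
    # owned by this rank instead of re-splitting the full model.
    engine = deepspeed.init_inference(
        model,
        mp_size=world_size,
        dtype=torch.half,
        checkpoint="./mp_sharded/<type>_ds-inference_config.json",
        replace_with_kernel_inject=True)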