From 5bf8655be624b3aeda799b80fddd220213491b04 Mon Sep 17 00:00:00 2001 From: hyunwoongko Date: Mon, 30 Aug 2021 18:31:21 +0900 Subject: [PATCH] Add tensor model parallel inference and training with GPT-Neo model --- .../models/gpt_neo/modeling_gpt_neo.py | 88 +- src/transformers/parallelization_utils.py | 1257 +++++++++++++++++ tests/parallelism/__init__.py | 0 tests/parallelism/test_gpt_neo_inference.py | 70 + ...t_gpt_neo_inference_with_model_parallel.py | 86 ++ tests/parallelism/test_gpt_neo_training.py | 146 ++ ...st_gpt_neo_training_with_model_parallel.py | 157 ++ 7 files changed, 1803 insertions(+), 1 deletion(-) create mode 100644 src/transformers/parallelization_utils.py create mode 100644 tests/parallelism/__init__.py create mode 100644 tests/parallelism/test_gpt_neo_inference.py create mode 100644 tests/parallelism/test_gpt_neo_inference_with_model_parallel.py create mode 100644 tests/parallelism/test_gpt_neo_training.py create mode 100644 tests/parallelism/test_gpt_neo_training_with_model_parallel.py diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 05e5b1ce281717..06d5511b84def1 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -33,6 +33,13 @@ SequenceClassifierOutputWithPast, ) from ...modeling_utils import PreTrainedModel +from ...parallelization_utils import ( + ColumnParallelLinear, + Layer, + ParallelizationMixin, + ParallelPolicy, + RowParallelLinear, +) from ...utils import logging from .configuration_gpt_neo import GPTNeoConfig @@ -579,7 +586,7 @@ def forward( return outputs # hidden_states, present, (attentions, cross_attentions) -class GPTNeoPreTrainedModel(PreTrainedModel): +class GPTNeoPreTrainedModel(PreTrainedModel, ParallelizationMixin): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. @@ -608,6 +615,24 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def parallelize( + self, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + ): + """ + Parallelize model by given model parallel sizes + + Args: + tensor_model_parallel_size (int): tensor model parallel size + pipeline_model_parallel_size (int): pipeline model parallel size + """ + self._parallelize( + policies=[GPTNeoParallelPolicy], + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + ) + GPT_NEO_START_DOCSTRING = r""" @@ -1144,3 +1169,64 @@ def forward( hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +class GPTNeoParallelPolicy(ParallelPolicy): + @staticmethod + def replace_arguments(config, world_size): + return { + # 1. reduce hidden size + "attn.attention.embed_dim": config.hidden_size // world_size, + # 2. reduce number of heads + "attn.attention.num_heads": config.num_heads // world_size, + } + + @staticmethod + def attn_qkv(): + return [ + Layer( + weight="attn.attention.q_proj.weight", + replace=ColumnParallelLinear, + ), + Layer( + weight="attn.attention.k_proj.weight", + replace=ColumnParallelLinear, + ), + Layer( + weight="attn.attention.v_proj.weight", + replace=ColumnParallelLinear, + ), + ] + + @staticmethod + def attn_out(): + return [ + Layer( + weight="attn.attention.out_proj.weight", + replace=RowParallelLinear, + ), + ] + + @staticmethod + def mlp_in(): + return [ + Layer( + weight="mlp.c_fc.weight", + bias="mlp.c_fc.bias", + replace=ColumnParallelLinear, + ), + ] + + @staticmethod + def mlp_out(): + return [ + Layer( + weight="mlp.c_proj.weight", + bias="mlp.c_proj.bias", + replace=RowParallelLinear, + ), + ] + + @staticmethod + def original_layer_class(): + return GPTNeoBlock diff --git a/src/transformers/parallelization_utils.py b/src/transformers/parallelization_utils.py new file mode 100644 index 00000000000000..9e1cf9f07429df --- /dev/null +++ b/src/transformers/parallelization_utils.py @@ -0,0 +1,1257 @@ +# coding=utf-8 +# Copyright 2021 TUNiB Inc, NVIDIA CORPORATION and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from abc import ABC, abstractmethod +from contextlib import suppress +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple, Type, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.autograd import Function + + +""" +Integration with Parallelformers +https://github.com/tunib-ai/parallelformers +""" + + +@dataclass +class Layer: + r""" + Dataclass used to describe a layer in the parallel policy object + + Attributes: + weight and bias (str): the names of the weight and bias tensors, respectively. You can use the syntax + such as `[ ]` or `.` to the tensor names. `.` is used as accessors in common programming languages and `[ ]` + is used to access elements in nn.ModuleList. + + n_fused (int): the number of areas used in fused attention layers. For example, GPT2 and TransfoXL have + fused attention layers that consist of query, key and value. These layers should not be simply chunked by + the number of GPUs. Instead, they should be divided into the query, key and value areas first. + + replace (Any): the layer that you want to replace an existing layer with. The parallelization process + by the tensor slicing method involves All-Reduce operations to collect tensors from all GPUs. + So, we need to insert some layer like RoeParallelLinear or ColumnParallelLinear to replace the existing layer. + + reversed (bool): this attribute is used to indicate whether tensors are reversed or not. Some models such as + GPT1 and GPT2 use the transformers.modeling_utils.Conv1D layer instead of the nn.Linear layer. + These layers store weight and bias tensors reversed. + + ignore_check (bool): this attribute is used when you want to ignore errors in case the layers do not exist. + Some models like Bert, Roberta and Electra have only encoder layers. but for Huggingface, + these models are also designed to be able to used as decoders. In these models, + some layers may or may not be created depending on the configurations. + In this case, you can use ignore_check option to ignore errors even if the layers do not always exist. + """ + + weight: str = None + bias: str = None + n_fused: int = None + replace: Any = None + reversed: bool = None + ignore_check: bool = False + + +class ParallelPolicy(ABC): + """ + Parallelization policy to apply parallelism per model. + You can check more details here: https://github.com/tunib-ai/parallelformers/blob/main/POLICY.md + TODO: It would be a great to write a description of this object in the Huggingface docs and replace the link above with the Huggingface docs.. + + Args: + layer (nn.Module): The layer to apply the parallel policy + + References: + The design of the ParallelPolicy class is inspired by Microsoft DeepSpeed. + https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py + """ + + def __init__(self, layer: nn.Module) -> None: + self.layer = layer + + @staticmethod + def replace_arguments(config, world_size: int) -> Dict: + """ + ParallelPolicy for argument replacement. + + Args: + config (PretrainedConfig): pretrained config object + world_size (int): world size of tensor model parallelization + + Returns: + Dict: Dictionary for argument replacement. + + Notes: + The format of the dictionary object is as follows. + + dict: + "param_name_1": reset_value_1, + "param_name_2": reset_value_2, + "param_name_3": reset_value_3, + ... + "param_name_n": reset_value_n + """ + return {} + + @staticmethod + def attn_qkv() -> List: + """ + Attention query, key, value projection layer + + Returns: + List[Layer]: List of layer object + """ + return [] + + @staticmethod + def attn_out() -> List: + """ + Attention output projection layer + + Returns: + List[Layer]: List of layer object + """ + return [] + + @staticmethod + def mlp_in() -> List: + """ + h -> 4h mlp layer + + Returns: + List[Layer]: List of layer object + """ + return [] + + @staticmethod + def mlp_out() -> List: + """ + 4h -> h mlp layer + + Returns: + List[Layer]: List of layer object + """ + return [] + + @staticmethod + @abstractmethod + def original_layer_class() -> Type[nn.Module]: + """ + Class to apply the parallel policy + e.g. BertLayer, GPT2Block, BartEncoderLayer, ... + + Returns: + Type[nn.Module]: original layer class + """ + raise NotImplementedError + + def _add_parameters_to_dict(self, layer, param, param_dict, attr_dict): + orig_layer = None + + if param is not None: + param_name = f"layer.{param}" + orig_layer_name = ".".join(param.split(".")[:-1]) + orig_layer = rgetattr(self, f"layer.{orig_layer_name}") + + if rhasattr(self, param_name): + param_dict[param_name] = rgetattr(self, param_name) + attr_dict[param_name] = (layer.n_fused, layer.reversed) + + elif not layer.ignore_check: + raise Exception(f"'{self.original_layer_class().__qualname__}' object has no attribute '{param}'.") + + return orig_layer + + def preprocess( + self, + layers: List[Layer], + model_parallel_group, + ) -> Tuple[Dict, Dict, Dict, Dict]: + """ + Preprocess policy object to replace tensors + + Args: + layers (List[Layer]): list of layers in the policy object + model_parallel_group: model parallel group + + Returns: + Tuple[Dict, Dict, Dict, Dict]: + Tuple of dictionaries of parameters and attributes required for tensor slicing + """ + weight_param_dict, bias_param_dict = {}, {} + weight_attr_dict, bias_attr_dict = {}, {} + + for layer in layers: + orig_layer_from_w = self._add_parameters_to_dict( + layer, + layer.weight, + weight_param_dict, + weight_attr_dict, + ) + + orig_layer_from_b = self._add_parameters_to_dict( + layer, + layer.bias, + bias_param_dict, + bias_attr_dict, + ) + + if layer.replace is not None: + orig_layer = None + if orig_layer_from_w is not None: + orig_layer = orig_layer_from_w + elif orig_layer_from_b is not None: + orig_layer = orig_layer_from_b + + if orig_layer is not None: + orig_layer.__class__ = layer.replace + orig_layer.inject_model_parallel_group(model_parallel_group) + + return weight_param_dict, bias_param_dict, weight_attr_dict, bias_attr_dict + + +class ParallelizationMixin: + """ + ParallelizationMixin has 5 distributed process group + 1) tensor parallel group + 2) pipeline parallel group + 3) data parallel group + 4) embedding parallel group (not supported in draft version) + 5) model parallel group (for backward compatibility with DeepSpeed) + + Reference: + The design of the ParallelizationMixin class is inspired by Nvidia Megatron-LM. + Most of the code was copied from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/initialize.py. + """ + + _TENSOR_MODEL_PARALLEL_GROUP = None + _TENSOR_MODEL_PARALLEL_RANK = None + _TENSOR_MODEL_PARALLEL_WORLD_SIZE = None + _PIPELINE_MODEL_PARALLEL_GROUP = None + _PIPELINE_GLOBAL_RANKS = None + _PIPELINE_MODEL_PARALLEL_RANK = None + _PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None + _DATA_PARALLEL_GROUP = None + _EMBEDDING_GROUP = None + _MODEL_PARALLEL_GROUP = None + + def _is_uninitialized(self): + """Useful for code segments that may be accessed with or without mpu initialization""" + return self._DATA_PARALLEL_GROUP is None + + def _initialize_model_parallel( + self, + tensor_model_parallel_size_: int = 1, + pipeline_model_parallel_size_: int = 1, + ): + """ + Initialize all distributed parallel groups + + Args: + tensor_model_parallel_size_ (int): the number of GPUs used to parallelize model tensor + pipeline_model_parallel_size_ (int): the number of GPUs used to parallelize model pipeline + + Notes: + Let's say we have a total of 16 GPUs denoted g0 ... g15 and we use 2 GPUs to parallelize the model tensor, + and 4 GPUs to parallelize the model parallel. The present method will create 8 tensor model-parallel group, + 4 pipeline model parallel groups and 8 data parallel groups as: + + - width: 4 pipeline parallel group + [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] + - height: 8 tensor parallel group + [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] + - depth: 8 data parallel group + [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] + + [g02, g06, g10, g14] + / | / | + [g00, g04, g08, g12] | + | | | | + 3D parallel | [g03, g07, g11, g15] + | / | / + [g01, g05, g09, g13] + + +---+ +---------+ +---------+ +---------+ +---------+ +---+ + tensor |g00| | g00 | | g04 | | g08 | | g12 | |g12| + data |---| +---------+ +---------+ +---------+ +---------+ |---| ===> forward + tensor |g01| | g01 | | g05 | | g09 | | g13 | |g13| + +---+ +---------+ +---------+ +---------+ +---------+ +---+ + emb pipeline pipeline pipeline pipeline emb + + +---+ +---------+ +---------+ +---------+ +---------+ +---+ + tensor |g02| | g02 | | g06 | | g10 | | g12 | |g14| + data |---| +---------+ +---------+ +---------+ +---------+ |---| ===> forward + tensor |g03| | g03 | | g07 | | g11 | | g15 | |g15| + +---+ +---------+ +---------+ +---------+ +---------+ +---+ + emb pipeline pipeline pipeline pipeline emb + """ + + assert dist.is_initialized() + world_size = dist.get_world_size() + current_rank = dist.get_rank() + + # 1. Ensure model parallel size must be smaller than world size + tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size) + pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size) + total_model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size + data_parallel_size = world_size // total_model_parallel_size + assert ( + world_size % total_model_parallel_size == 0 + ), "World size must be divisible by total model parallel size (TP + PP)." + + num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size + num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size + + # 2. Build data parallel groups + all_data_parallel_group_ranks = self._build_data_parallel_groups( + current_rank, + pipeline_model_parallel_size, + tensor_model_parallel_size, + num_pipeline_model_parallel_groups, + ) + + # 3. Build model parallel groups + self._build_model_parallel_groups( + current_rank, + data_parallel_size, + all_data_parallel_group_ranks, + ) + + # 4. Build tensor parallel groups + self._build_tensor_model_parallel_groups( + current_rank, + tensor_model_parallel_size, + num_tensor_model_parallel_groups, + ) + + # 5. Build pipeline and embedding parallel groups + self._build_pipeline_and_embedding_groups( + current_rank, + world_size, + num_pipeline_model_parallel_groups, + ) + + def _build_data_parallel_groups( + self, + current_rank: int, + pipeline_model_parallel_size: int, + tensor_model_parallel_size: int, + num_pipeline_model_parallel_groups: int, + ) -> List[List[int]]: + """ + Build data parallel groups + + Args: + current_rank (int): current rank + pipeline_model_parallel_size (int): the number of GPUs used to parallelize model tensor + tensor_model_parallel_size (int): the number of GPUs used to parallelize model pipeline + num_pipeline_model_parallel_groups (int): the number of pipeline parallel groups + + Returns: + List[List[int]]: all data parallel group ranks + """ + + assert self._DATA_PARALLEL_GROUP is None, "Data parallel group is already initialized." + + all_data_parallel_group_ranks = [] + for i in range(pipeline_model_parallel_size): + start_rank = i * num_pipeline_model_parallel_groups + end_rank = (i + 1) * num_pipeline_model_parallel_groups + + for j in range(tensor_model_parallel_size): + ranks = list( + range( + start_rank + j, + end_rank, + tensor_model_parallel_size, + ) + ) + + all_data_parallel_group_ranks.append(ranks) + group = dist.new_group(ranks) + if current_rank in ranks: + self._DATA_PARALLEL_GROUP = group + + return all_data_parallel_group_ranks + + def _build_model_parallel_groups( + self, + current_rank: int, + data_parallel_size: int, + all_data_parallel_group_ranks: List[List[int]], + ): + """ + Build model parallel groups + + Args: + current_rank (int): current rank + data_parallel_size (int): the number of GPUs used to parallelize data + all_data_parallel_group_ranks (List[List[int]]): all data parallel group ranks + """ + assert self._MODEL_PARALLEL_GROUP is None, "Model parallel group is already initialize" + + for i in range(data_parallel_size): + ranks = [data_parallel_ranks[i] for data_parallel_ranks in all_data_parallel_group_ranks] + + group = dist.new_group(ranks) + if current_rank in ranks: + self._MODEL_PARALLEL_GROUP = group + + def _build_tensor_model_parallel_groups( + self, + current_rank: int, + tensor_model_parallel_size: int, + num_tensor_model_parallel_groups: int, + ): + """ + Build tensor model parallel groups + + Args: + current_rank (int): current rank + tensor_model_parallel_size (int): the number of GPUs used to parallelize model tensor + num_tensor_model_parallel_groups (int): the number of tensor parallel groups + """ + assert self._TENSOR_MODEL_PARALLEL_GROUP is None, "Tensor parallel group is already initialized." + + for i in range(num_tensor_model_parallel_groups): + ranks = list( + range( + i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size, + ) + ) + + group = dist.new_group(ranks) + if current_rank in ranks: + self._TENSOR_MODEL_PARALLEL_GROUP = group + + def _build_pipeline_and_embedding_groups( + self, + current_rank: int, + world_size: int, + num_pipeline_model_parallel_groups: int, + ): + """ + Build pipeline and embedding groups + + Args: + current_rank (int): current rank + world_size (int): world size + num_pipeline_model_parallel_groups (int): the number of pipeline parallel groups + """ + + assert self._PIPELINE_MODEL_PARALLEL_GROUP is None, "Pipeline parallel group is already initialized." + assert self._EMBEDDING_GROUP is None, "Embedding group is already initialized." + + for i in range(num_pipeline_model_parallel_groups): + # step 1. create pipeline parallel groups + pipeline_ranks = list( + range( + i, + world_size, + num_pipeline_model_parallel_groups, + ) + ) + + pipeline_group = dist.new_group(pipeline_ranks) + if current_rank in pipeline_ranks: + self._PIPELINE_MODEL_PARALLEL_GROUP = pipeline_group + self._PIPELINE_GLOBAL_RANKS = pipeline_ranks + + # step 2. create embedding parallel group + if len(pipeline_ranks) > 1: + embedding_ranks = [pipeline_ranks[0], pipeline_ranks[-1]] + # first, last stage go to embedding parallel group + else: + embedding_ranks = pipeline_ranks + + embedding_group = dist.new_group(embedding_ranks) + if current_rank in embedding_ranks: + self._EMBEDDING_GROUP = embedding_group + + def model_parallel_is_initialized(self): + """Check if model and data parallel groups are initialized.""" + if ( + self._TENSOR_MODEL_PARALLEL_GROUP is None + or self._PIPELINE_MODEL_PARALLEL_GROUP is None + or self._DATA_PARALLEL_GROUP is None + ): + return False + return True + + def get_model_parallel_group(self): + """Get the model parallel group the caller rank belongs to.""" + assert self._MODEL_PARALLEL_GROUP is not None, "Model parallel group is not initialized." + return self._MODEL_PARALLEL_GROUP + + def get_tensor_model_parallel_group(self): + """Get the tensor parallel group the caller rank belongs to.""" + assert self._TENSOR_MODEL_PARALLEL_GROUP is not None, "Tensor parallel group is not initialized." + return self._TENSOR_MODEL_PARALLEL_GROUP + + def get_pipeline_model_parallel_group(self): + """Get the pipeline parallel group the caller rank belongs to.""" + assert self._PIPELINE_MODEL_PARALLEL_GROUP is not None, "Pipeline parallel group is not initialized." + return self._PIPELINE_MODEL_PARALLEL_GROUP + + def get_data_parallel_group(self): + """Get the data parallel group the caller rank belongs to.""" + assert self._DATA_PARALLEL_GROUP is not None, "Data parallel group is not initialized." + return self._DATA_PARALLEL_GROUP + + def get_embedding_group(self): + """Get the embedding group the caller rank belongs to.""" + assert self._EMBEDDING_GROUP is not None, "Embedding group is not initialized." + return self._EMBEDDING_GROUP + + def get_tensor_model_parallel_world_size(self): + """Return world size for the tensor parallel group.""" + if self._TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: + return self._TENSOR_MODEL_PARALLEL_WORLD_SIZE + return dist.get_world_size(group=self.get_tensor_model_parallel_group()) + + def set_tensor_model_parallel_world_size(self, world_size: int): + """Set the tensor parallel size""" + self._TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size + + def get_pipeline_model_parallel_world_size(self): + """Return world size for the pipeline parallel group.""" + if self._PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: + return self._PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return dist.get_world_size(group=self.get_pipeline_model_parallel_group()) + + def set_pipeline_model_parallel_world_size(self, world_size: int): + """Set the pipeline parallel size""" + self._PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size + + def get_data_parallel_world_size(self): + """Return world size for the data parallel group.""" + return dist.get_world_size(group=self.get_data_parallel_group()) + + def get_tensor_model_parallel_rank(self): + """Return my rank for the tensor parallel group.""" + if self._TENSOR_MODEL_PARALLEL_RANK is not None: + return self._TENSOR_MODEL_PARALLEL_RANK + return dist.get_rank(group=self.get_tensor_model_parallel_group()) + + def set_tensor_model_parallel_rank(self, rank: int): + """Set tensor parallel rank.""" + self._TENSOR_MODEL_PARALLEL_RANK = rank + + def get_pipeline_model_parallel_rank(self): + """Return my rank for the pipeline parallel group.""" + if self._PIPELINE_MODEL_PARALLEL_RANK is not None: + return self._PIPELINE_MODEL_PARALLEL_RANK + return dist.get_rank(group=self.get_pipeline_model_parallel_group()) + + def set_pipeline_model_parallel_rank(self, rank: int): + """Set pipeline parallel rank.""" + self._PIPELINE_MODEL_PARALLEL_RANK = rank + + def get_tensor_model_parallel_src_rank(self): + """Calculate the global rank corresponding to the first local rank in the tensor parallel group.""" + global_rank = dist.get_rank() + local_world_size = self.get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + def get_pipeline_model_parallel_first_rank(self): + assert self._PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + return self._PIPELINE_GLOBAL_RANKS[0] + + def get_pipeline_model_parallel_last_rank(self): + assert self._PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + return self._PIPELINE_GLOBAL_RANKS[self.get_pipeline_model_parallel_world_size() - 1] + + def get_pipeline_model_parallel_next_rank(self): + assert self._PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + + rank_in_pipeline = self.get_pipeline_model_parallel_rank() + world_size = self.get_pipeline_model_parallel_world_size() + return self._PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] + + def get_pipeline_model_parallel_prev_rank(self): + assert self._PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + rank_in_pipeline = self.get_pipeline_model_parallel_rank() + world_size = self.get_pipeline_model_parallel_world_size() + return self._PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + + def get_data_parallel_rank(self): + """Return my rank for the data parallel group.""" + return dist.get_rank(group=self.get_data_parallel_group()) + + def is_pipeline_first_stage(self): + """Return True if in the first pipeline parallel stage, False otherwise.""" + return self.get_pipeline_model_parallel_rank() == 0 + + def is_pipeline_last_stage(self): + """Return True if in the last pipeline parallel stage, False otherwise.""" + return self.get_pipeline_model_parallel_rank() == self.get_pipeline_model_parallel_world_size() - 1 + + def destroy_model_parallel(self): + """Set the groups to none.""" + self._TENSOR_MODEL_PARALLEL_GROUP = None + self._PIPELINE_MODEL_PARALLEL_GROUP = None + self._DATA_PARALLEL_GROUP = None + self._MODEL_PARALLEL_GROUP = None + self._EMBEDDING_GROUP = None + + @staticmethod + def _initialize_distributed(backend="nccl"): + """Initialize torch.distributed.""" + # Get rank and world size. + + rank = int(os.getenv("RANK", "0")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + + # Set the device id. + device = rank % torch.cuda.device_count() + if local_rank is not None: + device = local_rank + torch.cuda.set_device(device) + + # Call the init process. + init_method = "tcp://" + master_ip = os.getenv("MASTER_ADDR", "localhost") + master_port = os.getenv("MASTER_PORT", "6000") + init_method += master_ip + ":" + master_port + + dist.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + init_method=init_method, + ) + + def _parallelize( + self, + policies: List[Type[ParallelPolicy]], + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + ): + """ + Parallelize model by given model parallel sizes + + Args: + policies (List[Type[ParallelPolicy]]): List of policy class + tensor_model_parallel_size (int): tensor model parallel size + pipeline_model_parallel_size (int): pipeline model parallel size + """ + assert ( + pipeline_model_parallel_size == 1 + ), "Currently, pipeline model parallelism is not supported. You must set `pipeline_model_parallel_size` to 1." + + # TODO: Add Data distributed parallel working with Tensor model parallel (2D Parallelism) + # from torch.nn.parallel import DistributedDataParallel + + assert int(os.getenv("WORLD_SIZE")) == tensor_model_parallel_size, ( + "Currently Tensor model parallelism is not compatible with Data parallelism. " + "You must set to `tensor_model_parallel_size` to total world size." + ) + + if not dist.is_initialized(): + self._initialize_distributed() + + if self._is_uninitialized(): + self._initialize_model_parallel( + tensor_model_parallel_size_=tensor_model_parallel_size, + pipeline_model_parallel_size_=pipeline_model_parallel_size, + ) + + replacer = TensorReplacer( + model=self, + model_parallel_group=self.get_tensor_model_parallel_group(), + policies=policies, + ) + + replacer.replace_layers() + + # Lazy GPU memory allocation (Only cpu tensors are loaded onto all gpus) + for k, v in dict(self.state_dict()).items(): + if not v.is_cuda: + if torch.is_tensor(v): + rsetattr(self, k + ".data", v.to(torch.cuda.current_device())) + + else: + raise Exception("Model is already parallelized !") + + +def igetattr(obj, attr, *args): + """ + Indexed getattr function + + Examples: + >>> model = Model() + >>> igetattr(model, "weight[2]") + """ + if "[" in attr and "]" in attr: + attr = "".join("\t".join(attr.split("[")).split("]")).split("\t") + indexes = "[".join(attr[1:-1]).replace("[", "][") + indexes = "[" + indexes + "]" if len(indexes) >= 1 else indexes + return igetattr(obj, attr[0] + indexes)[int(attr[-1])] + else: + return getattr(obj, attr, *args) + + +def isetattr(obj, attr, val): + """ + Indexed setattr function + + Examples: + >>> model = Model() + >>> isetattr(model, "weight[2]", new_weight) + """ + if "[" in attr and "]" in attr: + element = attr.split("[")[0] + element_obj = getattr(obj, element) + attr = "".join("\t".join(attr.split("[")).split("]")).split("\t")[1:] + + for i in range(len(attr) - 1): + element_obj = element_obj[int(attr[i])] + + element_obj[int(attr[-1])] = val + else: + setattr(obj, attr, val) + + +def rgetattr(obj, attr, default=None): + """ + Recursive getattr function based on igetattr + + Examples: + >>> model = Model() + >>> rgetattr(model, "layer[2].attention.weight[3].data") + """ + + try: + left, right = attr.split(".", 1) + except BaseException: + return igetattr(obj, attr, default) + return rgetattr(igetattr(obj, left), right, default) + + +def rsetattr(obj, attr, val): + """ + Recursive setattr function based on isetattr + + Examples: + >>> model = Model() + >>> rgetattr(model, "layer[2].attention.weight[3].data", new_data) + """ + + try: + left, right = attr.split(".", 1) + except BaseException: + return isetattr(obj, attr, val) + return rsetattr(igetattr(obj, left), right, val) + + +def rhasattr(obj, attr): + """ + Recursive hasattr function based on igetattr + + Examples: + >>> model = Model() + >>> rhasattr(model, "layer[2].attention.weight[3].data") + True + """ + + try: + left, right = attr.split(".", 1) + except BaseException: + return hasattr(obj, attr) + try: + get = igetattr(obj, left) + except BaseException: + return False + return rhasattr(get, right) + + +class TensorSlicer(object): + r""" + An object that slices tensors into rows or columns as described in the Megatron-LM paper + + Args: + model_parallel_group: Distributed group for model parallelism + """ + + def __init__(self, model_parallel_group) -> None: + if dist.is_initialized() and model_parallel_group is not None: + self.gpu_index = dist.get_rank(model_parallel_group) + self.world_size = dist.get_world_size(model_parallel_group) + else: + self.gpu_index = 0 + self.world_size = 1 + + def slice_tensor( + self, + tensor: Dict, + attributes: Dict, + dim: int, + is_bias: bool, + ) -> Tuple: + """ + Slice tensors into rows or columns as described in the Megatron-LM paper + + Args: + tensor (Dict): tensor dictionaries + attributes (Dict): attributes dictionaries + dim (int): dimension for slicing + is_bias (bool): whether tensor is bias or not + + Returns: + Tuple: tuple of sliced tensors + """ + if not tensor: + return (None,) + + n_fused_list, reversed_list = [], [] + for (key_tensor, _), (key_attr, val_attr) in zip(tensor.items(), attributes.items()): + if key_tensor == key_attr: + n_fused, _reversed = val_attr + n_fused_list.append(n_fused) + reversed_list.append(_reversed) + + tensor = list(tensor.values()) + slices = [] + + for proj_layer, n_fused, _reversed in zip(tensor, n_fused_list, reversed_list): + device = torch.cuda.current_device() + dim = dim if not _reversed or is_bias else abs(dim - 1) + n_fused = 1 if not n_fused else n_fused + proj_layer = proj_layer.chunk(n_fused * self.world_size, dim=dim) + + if n_fused > 1: + ranks = (len(proj_layer) + self.world_size - 1) // self.world_size + proj_layer = [proj_layer[i * self.world_size : (i + 1) * self.world_size] for i in range(ranks)] + proj_layer = list(map(lambda x: torch.cat([*x], dim=-1), zip(*proj_layer))) + + proj_layer = proj_layer[self.gpu_index].to(device) + slices.append(proj_layer) + + return tuple(slices) + + def slice_weight_and_bias(self, tensors: Tuple, attributes: Tuple, dim: int, slice_bias: bool) -> Tuple: + """ + Slice weight and bias for model parallelization + + Args: + tensors (Tuple): tuple of weight and bias dictionaries + attributes (Tuple): tuple of weight attributes and bias attributes dictionaries + dim (int): dimension for slicing + slice_bias (bool): whether slice bias or not + + Returns: + Tuple: tuple of weights and biases + """ + + weight, bias = tensors + w_attr, b_attr = attributes + + weight = self.slice_tensor(weight, w_attr, dim=dim, is_bias=False) + + if slice_bias: + bias = self.slice_tensor(bias, b_attr, dim=0, is_bias=True) + else: + bias = tuple(bias.values()) + + return weight, bias + + def column_slice(self, tensors: Tuple, attributes: Tuple) -> Tuple: + """ + Slice tensors in the column direction. + + Args: + tensors (Tuple): tuple of weight and bias dictionaries + attributes (Tuple): tuple of weight attributes and bias attributes dictionaries + + Notes: + nn.Linear layer of Pytorch stores parameters like torch.Tensor(out_features, in_features) + So, dimension of column slice is 0 (out_features). + + Returns: + Tuple: tuple of weights and biases + """ + return self.slice_weight_and_bias( + tensors, + attributes=attributes, + dim=0, + slice_bias=True, + ) + + def row_slice(self, tensors: Tuple, attributes: Tuple) -> Tuple: + """ + Slice tensors in the row direction. + + Args: + tensors (Tuple): tuple of weight and bias dictionaries + attributes (Tuple): tuple of weight attributes and bias attributes dictionaries + + Notes: + nn.Linear layer of Pytorch stores parameters like torch.Tensor(out_features, in_features) + So, dimension of row slice is 1 (in_features). + + Returns: + Tuple: tuple of weights and biases + """ + return self.slice_weight_and_bias( + tensors, + attributes=attributes, + dim=1, + slice_bias=False, + ) + + +class TensorReplacer(object): + r""" + Replace original layer into Megatron-LM layer. + + Args: + model (nn.Module): Huggingface pre-trained transformer model + model_parallel_group (Any): process group for model parallelism + policies (List[Type[ParallelPolicy]]): parallelization policy classes + """ + + def __init__( + self, + model: Union[nn.Module, ParallelizationMixin], + model_parallel_group: Any, + policies: List[Type[ParallelPolicy]], + ) -> None: + self.model = model + self.config = model.config + self.model_parallel_group = model_parallel_group + self.world_size = dist.get_world_size(self.model_parallel_group) + self.slicer = TensorSlicer(self.model_parallel_group) + self.policies = policies + + def replace_layers(self) -> None: + """Replace original huggingface layers to Megtraon tensor sliced layers""" + for policy in self.policies: + self.replace_to_megatron_layers(self.model, policy) + + def replace_to_megatron_layers( + self, + model: nn.Module, + policy_cls: Type[ParallelPolicy], + ) -> nn.Module: + """ + Replace original layers to sliced layers + + Args: + model (nn.Module): model weight + policy_cls (Type[ParallelPolicy]): class of policy + + Returns: + nn.Module: parallelized parameters + """ + for name, child in model.named_children(): + if child.__class__ == policy_cls.original_layer_class(): + policy = policy_cls(layer=child) + arguments = policy.replace_arguments(self.config, self.world_size) + + for k, v in arguments.items(): + with suppress(Exception): + rsetattr(policy, f"layer.{k}", v) + + megatron_layer = self.convert_to_megatron_layer(policy) + rsetattr(model, name, megatron_layer) + + self.replace_to_megatron_layers(child, policy_cls) + + return model + + def set_parameters( + self, + policy: ParallelPolicy, + weight_name: Dict[str, torch.Tensor], + bias_name: Dict[str, torch.Tensor], + weight_param: Tuple[torch.Tensor], + bias_param: Tuple[torch.Tensor], + suffix: str = "data", + ) -> ParallelPolicy: + """ + Set sliced parameters into original model + + Args: + policy (ParallelPolicy): policy object + weight_name (Dict[str, Tensor]): names of layer's weight + bias_name (Dict[str, Tensor]): names of layer's bias + weight_param (Tuple[Tensor]): parameters of sliced weight + bias_param (Tuple[Tensor]): parameters of sliced bias + suffix (str): name of suffix in the parameters + + Returns: + ParallelPolicy: policy object + """ + for name, param in zip(weight_name, weight_param): + rsetattr(policy, f"{name}.{suffix}", param) + self.resize_layer(policy, name, param.size()) + + for name, param in zip(bias_name, bias_param): + rsetattr(policy, f"{name}.{suffix}", param) + + return policy + + @staticmethod + def resize_layer(policy: ParallelPolicy, name: str, size: torch.Size) -> None: + """ + Apply resize layer size to original layer object + + Args: + policy (ParallelPolicy): policy object + name (str): name of parameters + size (torch.Size): size of resized parameters + """ + layer_name = ".".join(f"{name}".split(".")[:-1]) + if rhasattr(policy, f"{layer_name}.nf"): + rsetattr(policy, f"{layer_name}.nf", size[1]) + + else: + for i, direction in enumerate(["out", "in"]): + if rhasattr(policy, f"{layer_name}.{direction}_{name}"): + rsetattr(policy, f"{layer_name}.{direction}_{name}", size[i]) + + def convert_to_megatron_layer(self, policy: ParallelPolicy) -> nn.Module: + """ + Convert original layers to sliced layers. + + Args: + policy (ParallelPolicy): policy object + + Returns: + nn.Module: sliced model layer + """ + attn_qkvw, attn_qkvb, attn_qkvw_attr, attn_qkvb_attr = policy.preprocess( + policy.attn_qkv(), + self.model_parallel_group, + ) + attn_outw, attn_outb, attn_outw_attr, attn_outb_attr = policy.preprocess( + policy.attn_out(), + self.model_parallel_group, + ) + mlp_inw, mlp_inb, mlp_inw_attr, mlp_inb_attr = policy.preprocess( + policy.mlp_in(), + self.model_parallel_group, + ) + mlp_outw, mlp_outb, mlp_outw_attr, mlp_outb_attr = policy.preprocess( + policy.mlp_out(), + self.model_parallel_group, + ) + + policy = self.set_parameters( + policy, + attn_qkvw, + attn_qkvb, + *self.slicer.column_slice( + (attn_qkvw, attn_qkvb), + (attn_qkvw_attr, attn_qkvb_attr), + ), + ) + + policy = self.set_parameters( + policy, + attn_outw, + attn_outb, + *self.slicer.row_slice( + (attn_outw, attn_outb), + (attn_outw_attr, attn_outb_attr), + ), + ) + + policy = self.set_parameters( + policy, + mlp_inw, + mlp_inb, + *self.slicer.column_slice( + (mlp_inw, mlp_inb), + (mlp_inw_attr, mlp_inb_attr), + ), + ) + + policy = self.set_parameters( + policy, + mlp_outw, + mlp_outb, + *self.slicer.row_slice( + (mlp_outw, mlp_outb), + (mlp_outw_attr, mlp_outb_attr), + ), + ) + + return policy.layer + + +def _broadcast(model_parallel_group, _input: torch.Tensor): + """Pass the input to the model parallel region.""" + return _input + + +def _reduce(model_parallel_group, _input: torch.Tensor): + """All-reduce the the input tensor across model parallel group.""" + if dist.get_world_size(model_parallel_group) == 1: + return _input + + dist.all_reduce( + _input, + group=model_parallel_group, + ) + + return _input + + +def _gather(model_parallel_group, _input: torch.Tensor): + """Gather tensors and concatenate along the last dimension.""" + world_size = dist.get_world_size(model_parallel_group) + + if world_size == 1: + return _input + + _input_list = [torch.ones_like(_input) for _ in range(world_size)] + + dist.all_gather( + _input_list, + _input, + group=model_parallel_group, + ) + + return torch.cat(_input_list, dim=-1).contiguous() + + +def _scatter(model_parallel_group, _input: torch.Tensor): + """Split the tensor along its last dimension and keep the corresponding slice.""" + world_size = dist.get_world_size(model_parallel_group) + + if world_size == 1: + return _input + + _input_list = torch.chunk(_input, world_size, dim=-1) + rank = dist.get_rank(model_parallel_group) + return _input_list[rank].contiguous() + + +class BroadcastFunction(Function): + model_parallel_group = None + + @staticmethod + def forward(ctx, _input: torch.Tensor) -> torch.Tensor: + return _broadcast(BroadcastFunction.model_parallel_group, _input) + + @staticmethod + def backward(ctx, grad_outputs: torch.Tensor) -> torch.Tensor: + return _reduce(BroadcastFunction.model_parallel_group, grad_outputs) + + +class ReduceFunction(Function): + model_parallel_group = None + + @staticmethod + def forward(ctx, _input: torch.Tensor) -> torch.Tensor: + return _reduce(ReduceFunction.model_parallel_group, _input) + + @staticmethod + def backward(ctx, grad_outputs: torch.Tensor) -> torch.Tensor: + return _broadcast(ReduceFunction.model_parallel_group, grad_outputs) + + +class GatherFunction(Function): + model_parallel_group = None + + @staticmethod + def forward(ctx, _input: torch.Tensor) -> torch.Tensor: + return _gather(GatherFunction.model_parallel_group, _input) + + @staticmethod + def backward(ctx, grad_outputs: torch.Tensor) -> torch.Tensor: + return _scatter(GatherFunction.model_parallel_group, grad_outputs) + + +class ScatterFunction(Function): + model_parallel_group = None + + @staticmethod + def forward(ctx, _input: torch.Tensor) -> torch.Tensor: + return _scatter(ScatterFunction.model_parallel_group, _input) + + @staticmethod + def backward(ctx, grad_outputs: torch.Tensor) -> torch.Tensor: + return _gather(ScatterFunction.model_parallel_group, grad_outputs) + + +class ParallelModule(object): + def inject_model_parallel_group(self, model_parallel_group): + self.model_parallel_group = model_parallel_group + self.broadcast_func = BroadcastFunction() + self.reduce_func = ReduceFunction() + self.gather_func = GatherFunction() + self.scatter_func = ScatterFunction() + + self.broadcast_func.model_parallel_group = model_parallel_group + self.reduce_func.model_parallel_group = model_parallel_group + self.gather_func.model_parallel_group = model_parallel_group + self.scatter_func.model_parallel_group = model_parallel_group + + # model parallel group will be injected by ParallelPolicy. + + def broadcast(self, _input): + if self.model_parallel_group is not None: + if dist.get_world_size(self.model_parallel_group) > 1: + return self.broadcast_func.apply(_input) + + return _input + + def reduce(self, _input): + if self.model_parallel_group is not None: + if dist.get_world_size(self.model_parallel_group) > 1: + return self.reduce_func.apply(_input) + + return _input + + def gather(self, _input): + if self.model_parallel_group is not None: + return self.gather_func.apply(_input) + + return _input + + def scatter(self, _input): + if self.model_parallel_group is not None: + return self.scatter_func.apply(_input) + + return _input + + +class ColumnParallelLinear(nn.Linear, ParallelModule): + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = self.broadcast(input) + outputs = input.matmul(self.weight.t()) + + if self.bias is not None: + outputs += self.bias + + return outputs + + +class RowParallelLinear(nn.Linear, ParallelModule): + def forward(self, input: torch.Tensor) -> torch.Tensor: + outputs = input.matmul(self.weight.t()) + outputs = self.reduce(outputs) + + if self.bias is not None: + outputs += self.bias + + return outputs diff --git a/tests/parallelism/__init__.py b/tests/parallelism/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/parallelism/test_gpt_neo_inference.py b/tests/parallelism/test_gpt_neo_inference.py new file mode 100644 index 00000000000000..c7b0805749d058 --- /dev/null +++ b/tests/parallelism/test_gpt_neo_inference.py @@ -0,0 +1,70 @@ +""" +Single GPU inference without model parallelism +It is for checking whether model parallel output is same with non-parallel output. +(just copy and paste !) + +FP32: +python test_gpt_neo_inference.py + +FP16: +python test_gpt_neo_inference.py --precision 16 +""" + +from argparse import ArgumentParser + +import torch + +from transformers import GPT2TokenizerFast, GPTNeoForCausalLM + + +parser = ArgumentParser() +parser.add_argument("--precision", default=32, type=int) +args = parser.parse_args() + +precision = args.precision +assert precision in [16, 32], "`--precision` must be on of [16, 32]" +os.environ["TOKENIZERS_PARALLELISM"] = "true" + + +tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-1.3B") +tokenizer.pad_token = tokenizer.eos_token + +model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B") + +if precision == 16: + model = model.half().cuda() +else: + model = model.cuda() + +tokenized = tokenizer("Hello. My name is Kevin. Today,", return_tensors="pt") +input_ids = tokenized.input_ids.cuda() +attention_mask = tokenized.attention_mask.cuda() + +output = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, +) + +decoded = tokenizer.batch_decode(output) +print(f"generation output:\n{decoded}\n") +print(f"GPU0: {torch.cuda.memory_allocated(0)}\n") + +""" +Note nvidia-smi shows the reserved cache area, it is difficult to check the exact allocated memory. +To check the allocated memory state well, use torch.cuda.memory_allocated() rather than nvidia-smi. + +FP32 result: +$ python test_gpt_neo_inference.py +generation output: +['Hello. My name is Kevin. Today, I’m going to be talking about the best'] + +GPU0: 5312648704 + + +FP16 result: +$ python test_gpt_neo_inference.py --precision 16 +generation output: +['Hello. My name is Kevin. Today, I’m going to be talking about the best'] + +GPU0: 2681497088 +""" diff --git a/tests/parallelism/test_gpt_neo_inference_with_model_parallel.py b/tests/parallelism/test_gpt_neo_inference_with_model_parallel.py new file mode 100644 index 00000000000000..8d5967f1bbdf0e --- /dev/null +++ b/tests/parallelism/test_gpt_neo_inference_with_model_parallel.py @@ -0,0 +1,86 @@ +""" +if you have 4 gpus run follow instruction +(just copy and paste !) + +FP32: +python -m torch.distributed.launch --nproc_per_node 4 test_gpt_neo_inference_with_model_parallel.py + +FP16: +python -m torch.distributed.launch --nproc_per_node 4 test_gpt_neo_inference_with_model_parallel.py --precision 16 +""" +from argparse import ArgumentParser + +import torch +import torch.distributed as dist + +from transformers import GPT2TokenizerFast, GPTNeoForCausalLM + + +parser = ArgumentParser() +parser.add_argument("--precision", default=32, type=int) +parser.add_argument("--local_rank", default=0, type=int) +args = parser.parse_args() + +precision = args.precision +local_rank = args.local_rank +assert precision in [16, 32], "`--precision` must be on of [16, 32]" +os.environ["TOKENIZERS_PARALLELISM"] = "true" + + +def print_rank_0(string="", **kwargs): + if dist.is_initialized(): + if dist.get_rank() == 0: + print(string, **kwargs) + else: + print(string, **kwargs) + + +tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-1.3B") +tokenizer.pad_token = tokenizer.eos_token + +model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B") + +if precision == 16: + model = model.half() + +model.parallelize(tensor_model_parallel_size=4) +# The model is parallelized with just one line of code. + +tokenized = tokenizer("Hello. My name is Kevin. Today,", return_tensors="pt") +input_ids = tokenized.input_ids.cuda() +attention_mask = tokenized.attention_mask.cuda() + +output = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, +) + +decoded = tokenizer.batch_decode(output) +print_rank_0(f"generation output:\n{decoded}\n") +print(f"GPU{local_rank}: {torch.cuda.memory_allocated(local_rank)}") + +""" +Note nvidia-smi shows the reserved cache area, it is difficult to check the exact allocated memory. +To check the allocated memory state well, use torch.cuda.memory_allocated() rather than nvidia-smi. + +FP32 result: +$ python -m torch.distributed.launch --nproc_per_node 4 test_gpt_neo_inference_with_model_parallel.py +generation output: +['Hello. My name is Kevin. Today, I’m going to be talking about the best'] + +GPU1: 1688180224 +GPU0: 1688180224 +GPU3: 1688180224 +GPU2: 1688180224 + + +FP16 result: +$ python -m torch.distributed.launch --nproc_per_node 4 test_gpt_neo_inference_with_model_parallel.py --precision 16 +generation output: +['Hello. My name is Kevin. Today, I’m going to be talking about the best'] + +GPU0: 869262848 +GPU1: 869262848 +GPU3: 869262848 +GPU2: 869262848 +""" diff --git a/tests/parallelism/test_gpt_neo_training.py b/tests/parallelism/test_gpt_neo_training.py new file mode 100644 index 00000000000000..9f42f09b54e459 --- /dev/null +++ b/tests/parallelism/test_gpt_neo_training.py @@ -0,0 +1,146 @@ +""" +Single GPU training without model parallelism +It is for checking whether model parallel works wellwehn +(just copy and paste !) + +python test_gpt_neo_training.py.py +""" +import os + +import torch +import torch.distributed as dist +import wandb +from datasets import load_dataset +from torch.optim.adam import Adam +from torch.utils.data import DataLoader +from tqdm import tqdm + +from transformers import GPT2TokenizerFast, GPTNeoForCausalLM + + +def print_rank_0(string="", **kwargs): + if dist.is_initialized(): + if dist.get_rank() == 0: + print(string, **kwargs) + else: + print(string, **kwargs) + + +os.environ["TOKENIZERS_PARALLELISM"] = "true" +model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B") +model = model.cuda() +tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-1.3B") +tokenizer.pad_token = tokenizer.eos_token + + +def preprocess_with_uniform_length_batching( + dataset, + name, + num_dataset=1000, + min_character=200, + max_character=500, +): + print_rank_0(f"prerocessing {name} dataset...") + + _dataset = [] + for sample in dataset[name]["text"]: + if min_character <= len(sample) <= max_character: + _dataset.append(sample) + if len(_dataset) == num_dataset: + break + + # order: longer => shorter + return sorted(_dataset, key=len, reverse=True) + + +dataset = load_dataset("wikitext", "wikitext-103-raw-v1") +dataset = preprocess_with_uniform_length_batching(dataset, name="train") +optimizer = Adam(model.parameters(), lr=3e-6, weight_decay=1e-7) +total_epoch, batch_size = 1, 16 + +data_loader = DataLoader( + dataset, + batch_size=batch_size, + drop_last=True, + num_workers=os.cpu_count(), +) + + +wandb.init( + project="test_gpt_neo_training", + name=f"gpt_neo_without_model_parallel_batch_size_{batch_size}", +) + +print_rank_0(f"start training ... check wandb monitor !") +for current_epoch in range(total_epoch): + for i, batch in tqdm(enumerate(data_loader), total=len(data_loader)): + # dynamic padding + tokens = tokenizer(batch, return_tensors="pt", padding=True, truncation=True) + input_ids = tokens.input_ids.cuda() + attention_mask = tokens.attention_mask.cuda() + + optimizer.zero_grad() + loss = model(input_ids, labels=input_ids, attention_mask=attention_mask).loss + wandb.log({"training loss": loss, "training ppl": torch.exp(loss)}) + + loss.backward() + optimizer.step() + +""" + + + +> before training: +|===============================+======================+======================| +| 0 A100-SXM4-40GB Off | 00000000:00:04.0 Off | 0 | +| N/A 36C P0 61W / 400W | 6380MiB / 40537MiB | 0% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ +| 1 A100-SXM4-40GB Off | 00000000:00:05.0 Off | 0 | +| N/A 34C P0 53W / 400W | 3MiB / 40537MiB | 0% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ +| 2 A100-SXM4-40GB Off | 00000000:00:06.0 Off | 0 | +| N/A 35C P0 58W / 400W | 3MiB / 40537MiB | 0% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ +| 3 A100-SXM4-40GB Off | 00000000:00:07.0 Off | 0 | +| N/A 36C P0 59W / 400W | 3MiB / 40537MiB | 0% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ + + +> during training: +|===============================+======================+======================| +| 0 A100-SXM4-40GB Off | 00000000:00:04.0 Off | 0 | +| N/A 43C P0 64W / 400W | 40375MiB / 40537MiB | 2% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ +| 1 A100-SXM4-40GB Off | 00000000:00:05.0 Off | 0 | +| N/A 34C P0 53W / 400W | 0MiB / 40537MiB | 0% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ +| 2 A100-SXM4-40GB Off | 00000000:00:06.0 Off | 0 | +| N/A 35C P0 58W / 400W | 0MiB / 40537MiB | 0% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ +| 3 A100-SXM4-40GB Off | 00000000:00:07.0 Off | 0 | +| N/A 36C P0 59W / 400W | 0MiB / 40537MiB | 0% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ + + +> after training: +100%|██████████████████████████████████████████████████████| 62/62 [00:25<00:00, 2.39it/s] +wandb: training loss 2.50484 +wandb: training ppl 12.24162 +wandb: _runtime 33 +wandb: _timestamp 1630230185 +wandb: _step 61 +wandb: Run history: +wandb: training loss █▇▆▆▆▅▆▅▅▆▄▅▅▅▅▄▄▆▄▄▄▄▂▅▅▆▄▁▅▄▄▄▁▂▃▅▄▄▄▄ +wandb: training ppl █▆▄▄▄▃▄▄▄▄▂▄▄▃▄▂▃▄▂▃▃▃▁▃▃▄▃▁▃▃▂▃▁▂▂▃▃▃▂▃ +wandb: _runtime ▁▁▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇████ +wandb: _timestamp ▁▁▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇████ +wandb: _step ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ +""" diff --git a/tests/parallelism/test_gpt_neo_training_with_model_parallel.py b/tests/parallelism/test_gpt_neo_training_with_model_parallel.py new file mode 100644 index 00000000000000..8a34a289a0e178 --- /dev/null +++ b/tests/parallelism/test_gpt_neo_training_with_model_parallel.py @@ -0,0 +1,157 @@ +""" +if you have 4 gpus run follow instruction +(just copy and paste !) + +python -m torch.distributed.launch --nproc_per_node 4 test_gpt_neo_training_with_model_parallel.py +""" + +import os +from argparse import ArgumentParser + +import torch +import torch.distributed as dist +import wandb +from datasets import load_dataset +from torch.optim.adam import Adam +from torch.utils.data import DataLoader +from tqdm import tqdm + +from transformers import GPT2TokenizerFast, GPTNeoForCausalLM + + +def print_rank_0(string="", **kwargs): + if dist.is_initialized(): + if dist.get_rank() == 0: + print(string, **kwargs) + else: + print(string, **kwargs) + + +parser = ArgumentParser() +parser.add_argument("--local_rank", default=0, type=int) +local_rank = parser.parse_args().local_rank + +os.environ["TOKENIZERS_PARALLELISM"] = "true" +model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B") +model.parallelize(tensor_model_parallel_size=4) +tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-1.3B") +tokenizer.pad_token = tokenizer.eos_token + + +def preprocess_with_uniform_length_batching( + dataset, + name, + num_dataset=1000, + min_character=200, + max_character=500, +): + print_rank_0(f"prerocessing {name} dataset...") + + _dataset = [] + for sample in dataset[name]["text"]: + if min_character <= len(sample) <= max_character: + _dataset.append(sample) + if len(_dataset) == num_dataset: + break + + # order: longer => shorter + return sorted(_dataset, key=len, reverse=True) + + +dataset = load_dataset("wikitext", "wikitext-103-raw-v1") +dataset = preprocess_with_uniform_length_batching(dataset, name="train") +optimizer = Adam(model.parameters(), lr=3e-6, weight_decay=1e-7) +total_epoch, batch_size = 1, 16 + +data_loader = DataLoader( + dataset, + batch_size=batch_size, + drop_last=True, + num_workers=os.cpu_count(), +) + + +if local_rank == 0: + wandb.init( + project="test_gpt_neo_training", + name=f"gpt_neo_with_model_parallel_batch_size_{batch_size}", + ) + +print_rank_0(f"start training ... check wandb monitor !") +for current_epoch in range(total_epoch): + for i, batch in tqdm(enumerate(data_loader), total=len(data_loader)): + # dynamic padding + tokens = tokenizer(batch, return_tensors="pt", padding=True, truncation=True) + input_ids = tokens.input_ids.cuda() + attention_mask = tokens.attention_mask.cuda() + + optimizer.zero_grad() + loss = model(input_ids, labels=input_ids, attention_mask=attention_mask).loss + + if local_rank == 0: + wandb.log({"training loss": loss, "training ppl": torch.exp(loss)}) + + loss.backward() + optimizer.step() + +""" + + +> before training: +|===============================+======================+======================| +| 0 A100-SXM4-40GB Off | 00000000:00:04.0 Off | 0 | +| N/A 36C P0 60W / 400W | 3330MiB / 40537MiB | 0% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ +| 1 A100-SXM4-40GB Off | 00000000:00:05.0 Off | 0 | +| N/A 35C P0 59W / 400W | 3330MiB / 40537MiB | 0% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ +| 2 A100-SXM4-40GB Off | 00000000:00:06.0 Off | 0 | +| N/A 36C P0 63W / 400W | 3330MiB / 40537MiB | 0% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ +| 3 A100-SXM4-40GB Off | 00000000:00:07.0 Off | 0 | +| N/A 37C P0 64W / 400W | 3330MiB / 40537MiB | 0% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ + + +> during training: +|===============================+======================+======================| +| 0 A100-SXM4-40GB Off | 00000000:00:04.0 Off | 0 | +| N/A 51C P0 189W / 400W | 19900MiB / 40537MiB | 79% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ +| 1 A100-SXM4-40GB Off | 00000000:00:05.0 Off | 0 | +| N/A 49C P0 247W / 400W | 19980MiB / 40537MiB | 91% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ +| 2 A100-SXM4-40GB Off | 00000000:00:06.0 Off | 0 | +| N/A 50C P0 267W / 400W | 20004MiB / 40537MiB | 88% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ +| 3 A100-SXM4-40GB Off | 00000000:00:07.0 Off | 0 | +| N/A 51C P0 256W / 400W | 19896MiB / 40537MiB | 86% Default | +| | | Disabled | ++-------------------------------+----------------------+----------------------+ + + +> after training: +100%|██████████████████████████████████████████████████████| 62/62 [00:22<00:00, 2.71it/s] +100%|██████████████████████████████████████████████████████| 62/62 [00:16<00:00, 3.69it/s] +100%|██████████████████████████████████████████████████████| 62/62 [00:22<00:00, 2.75it/s] +100%|██████████████████████████████████████████████████████| 62/62 [00:22<00:00, 2.75it/s] +wandb: training loss 2.50473 +wandb: training ppl 12.24022 +wandb: _runtime 24 +wandb: _timestamp 1630230355 +wandb: _step 61 +wandb: Run history: +wandb: training loss █▇▆▆▆▅▆▅▅▆▄▅▅▅▅▄▄▆▄▄▄▄▂▅▅▆▄▁▅▄▄▄▁▂▃▅▄▄▄▄ +wandb: training ppl █▆▄▄▄▃▄▄▄▄▂▄▄▃▄▂▃▄▂▃▃▃▁▃▃▄▃▁▃▃▂▃▁▂▂▃▃▃▂▃ +wandb: _runtime ▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇██ +wandb: _timestamp ▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇██ +wandb: _step ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ +wandb: +"""