
Commit

Merge branch 'ds-uly' of github.com:samadejacobs/transformers into deepspeed-sp-4.45.2
thepowerfuldeez committed Oct 10, 2024
2 parents 66e08db + 8766b91 commit b8c0b41
Showing 6 changed files with 183 additions and 1 deletion.
46 changes: 46 additions & 0 deletions docs/source/en/deepspeed.md
@@ -1141,6 +1141,52 @@ Using multiple GPUs with ZeRO-3 for generation requires synchronizing the GPUs b

For Transformers>=4.28, `synced_gpus` is automatically set to `True` if multiple GPUs are detected during generation.

### Non-Trainer Sequence Parallelism
DeepSpeed sequence parallelism, also known as [DeepSpeed Ulysses](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/README.md), is a distributed training technique targeting long-context LLM workloads. Sequence parallelism lets the sequence length (and model size) scale with the number of GPUs rather than being bounded by the memory of a single GPU. DeepSpeed sequence parallelism works with Hugging Face Transformers by adding `sequence_parallel_size` and `data_parallel_size` to the DeepSpeed configuration. In addition, your script must shard the input data along the sequence dimension, as in the example below.

```py
import deepspeed
import torch.distributed as dist
from transformers import AutoConfig, AutoModelForCausalLM

ds_config = {
    "sequence_parallel_size": 2,
    "data_parallel_size": 1,
    # ... the rest of your DeepSpeed configuration ...
}

config = AutoConfig.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    attn_implementation="flash_attention_2",
)

model, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
    dist_init_required=True,
)

# Shard each batch along the sequence dimension across the sequence-parallel group
spg = model.get_sequence_parallel_group()
seq_parallel_world_size = dist.get_world_size(spg)
seq_parallel_rank = dist.get_rank(spg)

for n, batch in enumerate(data_loader):
    seq_length = batch["input_ids"].size(1)
    assert seq_length % seq_parallel_world_size == 0
    sub_seq_length = seq_length // seq_parallel_world_size
    sub_seq_start = seq_parallel_rank * sub_seq_length
    sub_seq_end = (seq_parallel_rank + 1) * sub_seq_length

    batch["input_ids"] = batch["input_ids"][:, sub_seq_start:sub_seq_end]
    batch["labels"] = batch["labels"][:, sub_seq_start:sub_seq_end]

    # ... forward/backward/step ...

```

Hugging Face Transformers internally invokes DeepSpeed Ulysses to take advantage of multi-GPU optimizations during the pretraining, post-training, and fine-tuning of long-context LLMs. DeepSpeed sequence parallelism is fully compatible with FlashAttention. A detailed example script is available [here](https://github.com/microsoft/DeepSpeedExamples/blob/uly-hf/post_training/sequence_parallelism/test_ulysses.py).
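
For reference, here is a minimal sketch of one training step using the sharded batch from the loop above. The loss and label handling depend on your model and task, so treat this as illustrative rather than the canonical loop:

```py
# Illustrative only: one optimization step per sequence-sharded batch.
# Assumes `model` is the DeepSpeed engine returned by deepspeed.initialize() above
# and the batch tensors have already been sliced along the sequence dimension.
for batch in data_loader:
    batch = {k: v.to(model.device) for k, v in batch.items()}
    outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
    model.backward(outputs.loss)  # the engine handles loss scaling / gradient accumulation
    model.step()                  # optimizer step and zero_grad are managed by the engine
```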

Integration with [`Trainer`] is also underway; this documentation will be updated once that integration is available.


## Troubleshoot

When you encounter an issue, you should consider whether DeepSpeed is the cause of the problem because often it isn't (unless it's super obvious and you can see DeepSpeed modules in the exception)! The first step should be to retry your setup without DeepSpeed, and if the problem persists, then you can report the issue. If the issue is a core DeepSpeed problem and unrelated to the Transformers integration, open an Issue on the [DeepSpeed repository](https://github.com/microsoft/DeepSpeed).
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
@@ -48,6 +48,7 @@
"deepspeed_load_checkpoint",
"deepspeed_optim_sched",
"is_deepspeed_available",
"is_deepspeed_sp_enabled",
"is_deepspeed_zero3_enabled",
"set_hf_deepspeed_config",
"unset_hf_deepspeed_config",
@@ -149,6 +150,7 @@
deepspeed_load_checkpoint,
deepspeed_optim_sched,
is_deepspeed_available,
is_deepspeed_sp_enabled,
is_deepspeed_zero3_enabled,
set_hf_deepspeed_config,
unset_hf_deepspeed_config,
9 changes: 9 additions & 0 deletions src/transformers/integrations/deepspeed.py
@@ -445,3 +445,12 @@ def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_str
raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
else:
raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}")


def is_deepspeed_sp_enabled():
if is_deepspeed_available():
from deepspeed.utils import groups

return groups._get_sequence_parallel_world_size() > 1
else:
return False
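
A rough usage sketch (not part of the diff): the helper is expected to report `True` only once DeepSpeed has been initialized with a `sequence_parallel_size` greater than one, so scripts can use it to gate sequence-parallel-only logic.

```py
# Illustrative only: gate per-rank sequence sharding on the new helper.
from transformers.integrations.deepspeed import is_deepspeed_sp_enabled

if is_deepspeed_sp_enabled():
    # Sequence parallelism (DeepSpeed Ulysses) is active for this process,
    # e.g. shard the batch along the sequence dimension as shown in the docs above.
    ...
```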
25 changes: 25 additions & 0 deletions src/transformers/modeling_flash_attention_utils.py
@@ -20,9 +20,18 @@
import torch
import torch.nn.functional as F

from .integrations.deepspeed import ( # DeepSpeed seq parallelism (aka Ulysses)
is_deepspeed_available,
is_deepspeed_sp_enabled,
)
from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal


if is_deepspeed_available():
from deepspeed.sequence.layer import _SeqAllToAll
from deepspeed.utils import groups as ds_comm_groups


if is_flash_attn_2_available():
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -220,6 +229,16 @@ def _flash_attention_forward(
deterministic (`bool`, *optional*):
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
"""
if is_deepspeed_sp_enabled():
spg = ds_comm_groups._get_sequence_parallel_group()
# qkv tensors are of shape (batch_size, seq_len, num_heads, head_dim)
scatter_idx = 2 # Scatter on num_heads dimension
gather_idx = 1 # Gather on seq_len dimension
batch_dim_idx = 0 # Synonymous with the batch_first==true
query_states = _SeqAllToAll.apply(spg, query_states, scatter_idx, gather_idx, batch_dim_idx)
key_states = _SeqAllToAll.apply(spg, key_states, scatter_idx, gather_idx, batch_dim_idx)
value_states = _SeqAllToAll.apply(spg, value_states, scatter_idx, gather_idx, batch_dim_idx)

if not use_top_left_mask:
causal = is_causal
else:
@@ -298,4 +317,10 @@ def _flash_attention_forward(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
)

if is_deepspeed_sp_enabled():
scatter_idx = 1 # Scatter back on seq_len dimension
gather_idx = 2 # Gather on num_heads dimension
batch_dim_idx = 0
attn_output = _SeqAllToAll.apply(spg, attn_output, scatter_idx, gather_idx, batch_dim_idx)

return attn_output
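
To make the scatter/gather indices above concrete, here is an illustrative helper (not part of the diff) that computes one rank's q/k/v shapes before and after each all-to-all, assuming both the sequence length and the number of heads are divisible by the sequence-parallel world size `sp`:

```py
def ulysses_shapes(batch: int, seq_len: int, num_heads: int, head_dim: int, sp: int):
    """Illustrative only: one rank's tensor shapes around the two all-to-alls."""
    assert seq_len % sp == 0 and num_heads % sp == 0
    local_qkv = (batch, seq_len // sp, num_heads, head_dim)         # entering _flash_attention_forward
    after_first_a2a = (batch, seq_len, num_heads // sp, head_dim)   # scatter heads, gather sequence
    after_second_a2a = (batch, seq_len // sp, num_heads, head_dim)  # scatter sequence, gather heads
    return local_qkv, after_first_a2a, after_second_a2a

print(ulysses_shapes(batch=1, seq_len=4096, num_heads=32, head_dim=128, sp=4))
# ((1, 1024, 32, 128), (1, 4096, 8, 128), (1, 1024, 32, 128))
```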
21 changes: 20 additions & 1 deletion tests/deepspeed/test_deepspeed.py
@@ -122,7 +122,11 @@ def require_deepspeed_aio(test_case):
if is_deepspeed_available():
from deepspeed.utils import logger as deepspeed_logger # noqa
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
from transformers.integrations.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled # noqa
from transformers.integrations.deepspeed import (
deepspeed_config,
is_deepspeed_zero3_enabled,
is_deepspeed_sp_enabled,
) # noqa


def get_launcher(distributed=False):
@@ -1359,3 +1363,18 @@ def test_clm_from_config_zero3_fp16(self):
with CaptureStderr() as cs:
execute_subprocess_async(cmd, env=self.get_env())
self.assertIn("Detected DeepSpeed ZeRO-3", cs.err)

@parameterized.expand([2, 4, 8, 16])
@require_torch_accelerator
@require_torch_multi_accelerator
def test_deepspeed_sp(self, sp_size):
# Check if deepspeed_sp is enabled
# Run deepspeed sp with 2 GPUs and different sp_size
self.assertFalse(is_deepspeed_sp_enabled())
ds_args = [f"--sequence-length={sp_size}"]
script = [f"{self.test_file_dir_str}/test_ulysses.py"]
distributed = True
launcher = get_launcher(distributed)

cmd = launcher + script + ds_args
execute_subprocess_async(cmd, env=self.get_env())
81 changes: 81 additions & 0 deletions tests/deepspeed/test_ulysses.py
@@ -0,0 +1,81 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import sys

import torch
from deepspeed import initialize

from transformers import AutoModel
from transformers.integrations.deepspeed import is_deepspeed_sp_enabled # noqa
from transformers.modeling_flash_attention_utils import _flash_attention_forward


# Call Transformers flash attention with and without DeepSpeed SP enabled and check that the outputs match
def test_transformer_flash_attention(seq_len=2) -> None:
model = AutoModel.from_pretrained("bert-base-uncased")
batch_size = 2

# Test with deepspeed sp
sp_size = 2
dp_size = 1
ds_engine, _, _, _ = initialize(
model=model,
config_params={
"train_batch_size": batch_size,
"data_parallel_size": dp_size,
"sequence_parallel_size": sp_size,
},
)

assert is_deepspeed_sp_enabled()

seq_len = seq_len
hidden_dim = 16
num_heads = 4
head_dim = hidden_dim // num_heads
# Create input tensors
input_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim, device=ds_engine.device)
input_tensor = input_tensor.half()
attention_mask = None
q, k, v = input_tensor, input_tensor, input_tensor

output_tensor = _flash_attention_forward(q, k, v, attention_mask, query_length=seq_len, is_causal=False)
assert output_tensor is not None
assert output_tensor.shape == (batch_size, seq_len, num_heads, head_dim)

# Now test without deepspeed sp
sp_size = 1
dp_size = 2
ds_engine, _, _, _ = initialize(
model=model,
config_params={
"train_batch_size": batch_size,
"data_parallel_size": dp_size,
"sequence_parallel_size": sp_size,
},
)
assert not is_deepspeed_sp_enabled()

output_tensor_no_sp = _flash_attention_forward(q, k, v, attention_mask, query_length=seq_len, is_causal=False)
assert output_tensor_no_sp is not None
assert output_tensor_no_sp.shape == (batch_size, seq_len, num_heads, head_dim)
assert torch.allclose(output_tensor, output_tensor_no_sp)


if __name__ == "__main__":
torch.manual_seed(0)
seq_len = int((sys.argv[2]).split("=")[1])
test_transformer_flash_attention(seq_len=seq_len)
