DeepSpeed-Triton for Inference (#3748)
Co-authored-by: Stephen Youn <[email protected]>
Co-authored-by: Arash Bakhtiari <[email protected]>
Co-authored-by: Cheng Li <[email protected]>
Co-authored-by: Ethan Doe <[email protected]>
Co-authored-by: yidoe <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
7 people authored Jun 23, 2023
1 parent 2c62cb4 commit 4dc65f7
Showing 38 changed files with 2,345 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/formatting.yml
@@ -27,7 +27,7 @@ jobs:
- name: Install deepspeed
run: |
pip install .[dev,autotuning]
pip install .[dev,autotuning,triton]
ds_report
- name: Formatting checks
Binary file added blogs/assets/images/triton-bert-base-latency.png
Binary file added blogs/assets/images/triton-bert-large-latency.png
95 changes: 95 additions & 0 deletions blogs/deepspeed-triton/README.md
@@ -0,0 +1,95 @@
# DeepSpeed with Triton compiler

# 1. Overview

We have integrated [Triton](https://github.com/openai/triton), an open-source compiler for GPU programming, into DeepSpeed, which further boosts the inference speed of BERT-like models in float16 precision.
By replacing some CUDA kernels or torch operators with Triton kernels, we achieved 1.14\~1.68x speedup (or 12\~41% latency reduction) for different models and GPUs, as shown in Table 1.

<div align="center">

| Hardware | Bert-base | Bert-large | Roberta-base | Roberta-large |
|----------|:------:|:------:|:------:|:------:|
| A100 |1.65x | 1.68x | 1.53x | 1.61x |
| V100 | 1.29x | 1.14x | 1.23x | 1.21x |

Table 1. The average speedup (see NOTE below for more detail)


</div>

For transformer operators in float16, we have implemented kernels written in the Triton language that replace the ordinary CUDA kernels or torch operators.
The Triton kernels we implemented include softmax, layer-normalization, residual-addition and all the matrix multiplications except the MLP layers (see the NOTE below for details).
In our experiments, these Triton kernels reduce the average latency (over different sequence lengths) by 6\~24% (depending on the model and hardware) compared to the latency with CUDA-only kernels.
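
To give a flavor of what such kernels look like, below is a minimal, illustrative sketch of a row-wise softmax written in the Triton language. It is not the DeepSpeed implementation (the production kernels live under `deepspeed.ops.transformer.inference.triton`), and it assumes a contiguous float16 input matrix.

```python
# Illustrative sketch only (not the DeepSpeed kernel): a numerically stable
# row-wise softmax written in the Triton language for a contiguous fp16 matrix.
import torch
import triton
import triton.language as tl


@triton.jit
def _softmax_kernel(out_ptr, in_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)                    # one program instance per row
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols
    x = tl.load(in_ptr + row * n_cols + offs, mask=mask, other=-float('inf'))
    x = x.to(tl.float32)                      # accumulate in fp32 for stability
    x = x - tl.max(x, axis=0)
    num = tl.exp(x)
    y = num / tl.sum(num, axis=0)
    tl.store(out_ptr + row * n_cols + offs, y.to(tl.float16), mask=mask)


def triton_softmax(x: torch.Tensor) -> torch.Tensor:
    # Assumes x is a contiguous 2D float16 CUDA tensor.
    n_rows, n_cols = x.shape
    out = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    _softmax_kernel[(n_rows,)](out, x, n_cols, BLOCK_SIZE=BLOCK_SIZE)
    return out
```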


The figures below show the latency reduction in more detail.
Figure 1 visualizes the latency reduction across different sequence lengths on an A100 GPU for the Bert-base model.
The baseline (blue) is Huggingface transformers without any kernel injection, the orange is DeepSpeed with CUDA-only kernels, and the gray is DeepSpeed with Triton kernels.
Figure 2 shows the same plot for the Bert-large model on an A100 GPU.

<div align="center">

<img src="../assets/images/triton-bert-base-latency.png" width="500px" alt="triton-bert-base-latency"/>

*Figure 1: Normalized P90 latency for the Bert-base model on an A100 GPU across different sequence lengths*

<img src="../assets/images/triton-bert-large-latency.png" width="500px" alt="triton-bert-large-latency"/>

*Figure 2: Normalized P90 latency for the Bert-large model on an A100 GPU across different sequence lengths*

</div>


Next, we dive deeper into this new feature in DeepSpeed.

# 2. How to use Triton in DeepSpeed

You can enable the Triton kernels by setting the `use_triton` flag in the DeepSpeed inference config, as in the following example.

```python
import torch
import deepspeed
from transformers import pipeline

pipe = pipeline('fill-mask', model='bert-base-cased', framework='pt', device=0)
pipe.model = deepspeed.init_inference(pipe.model,
                                      dtype=torch.float16,
                                      replace_with_kernel_inject=True,
                                      enable_cuda_graph=True,
                                      use_triton=True,
                                      triton_autotune=True,
                                      max_out_tokens=pipe.tokenizer.model_max_length)
```
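
After initialization, the pipeline is used as usual and inference runs through the injected kernels; for example (an illustrative call, with the standard Huggingface fill-mask output format):

```python
# Illustrative usage after init_inference: a standard fill-mask query.
print(pipe("Paris is the [MASK] of France."))
```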


## Running BERT inference with Triton kernels

We use Bert-base as an example here.

```bash
pip install deepspeed[triton]

git clone https://github.com/microsoft/DeepSpeedExamples.git
cd DeepSpeedExamples/inference/huggingface/fill-mask

deepspeed --num_gpus 1 test-bert.py --triton
```

To run a performance benchmark, you can use the following command:

```bash
pip install deepspeed[triton]

git clone https://github.com/microsoft/DeepSpeedExamples.git
cd DeepSpeedExamples/benchmarks/inference

deepspeed --num_gpus 1 triton-bert-benchmark.py --model bert-base-cased --dtype fp16 --kernel-inject --deepspeed --graphs --triton
```

# NOTE
<!-- **_NOTE:_** -->
* For more information on how to use DeepSpeed, please visit our [GitHub Page](https://github.com/microsoft/DeepSpeedExamples) and our [website](https://www.deepspeed.ai/), where you can find blog posts, tutorials, and documentation.

* This feature currently supports only BERT, Roberta and other BERT-like models; text-generation models are not yet supported.

* To achieve the best performance with the Triton optimization, you need to activate CUDA graphs and `triton_autotune` in the DeepSpeed config. CUDA graphs avoid the overhead of JIT compilation and the deep call stack in Triton. `triton_autotune` executes an initial step to find the most suitable parameters for the Triton kernels, which may take some time.

* We used [Triton 2.0.0.post1 release](https://pypi.org/project/triton/2.0.0.post1/) in our experiments.

* In our experiments, we used a batch size of 1, a sequence-length range of 8 to 512, and the `fill-mask` task. Table 1 shows the average P90 latency over the entire sequence-length range, while Figures 1 and 2 show the P90 latency for specific sub-ranges. The baseline is Huggingface transformers without any optimization. The speedup is calculated as (baseline P90 latency) / (DeepSpeed-Triton P90 latency). We found that the CUDA kernel in the MLP performed better than the Triton kernel in our experiments, so we use a hybrid approach that combines both kernels when Triton is enabled in the DeepSpeed config. The sketch below illustrates how such a P90 comparison can be measured.
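
As a rough, illustrative sketch (not the benchmark script used for Table 1; the prompt, warm-up count, and iteration count are assumptions), the P90 latency and speedup described above could be measured like this:

```python
# Illustrative only: measure P90 latency of a fill-mask pipeline and compute
# the speedup as (baseline P90) / (DeepSpeed-Triton P90), as defined above.
import time

import numpy as np
import torch
from transformers import pipeline


def p90_latency_ms(pipe, text, warmup=10, iters=100):
    for _ in range(warmup):          # warm-up (JIT, CUDA graph capture, autotuning)
        pipe(text)
    times = []
    for _ in range(iters):
        torch.cuda.synchronize()
        start = time.perf_counter()
        pipe(text)
        torch.cuda.synchronize()
        times.append((time.perf_counter() - start) * 1000.0)
    return float(np.percentile(times, 90))


text = "Paris is the [MASK] of France."
baseline = pipeline('fill-mask', model='bert-base-cased', framework='pt', device=0)
base_p90 = p90_latency_ms(baseline, text)

# ds_pipe = ...  # the same pipeline after deepspeed.init_inference(..., use_triton=True)
# ds_p90 = p90_latency_ms(ds_pipe, text)
# print(f"speedup: {base_p90 / ds_p90:.2f}x")
```
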
6 changes: 6 additions & 0 deletions deepspeed/__init__.py
@@ -12,6 +12,12 @@
from torch.optim.lr_scheduler import _LRScheduler
from packaging import version as pkg_version

try:
import triton # noqa: F401
HAS_TRITON = True
except ImportError:
HAS_TRITON = False

from . import ops
from . import module_inject

19 changes: 19 additions & 0 deletions deepspeed/inference/config.py
@@ -4,6 +4,7 @@
# DeepSpeed Team

import torch
import deepspeed
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from pydantic import Field
@@ -152,6 +153,18 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
can run faster using the graph replay method.
"""

use_triton: bool = False
"""
Use this flag to use triton kernels for inference ops.
"""

triton_autotune: bool = False
"""
Use this flag to enable triton autotuning.
Turning it on improves performance but increases the first-run time due to
autotuning.
"""

zero: DeepSpeedZeroConfig = {}
"""
ZeRO configuration to use with the Inference Engine. Expects a dictionary
@@ -279,6 +292,12 @@ def moe_backward_compat(cls, field_value, values):
return DeepSpeedMoEConfig(moe=field_value)
return field_value

@validator("use_triton")
def has_triton(cls, field_value, values):
if field_value and not deepspeed.HAS_TRITON:
raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
return field_value

class Config:
# Get the str representation of the datatype for serialization
json_encoders = {torch.dtype: lambda x: str(x)}
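
For reference, these new fields can also be supplied through the inference config dictionary passed to `deepspeed.init_inference`. The snippet below is a minimal, hypothetical sketch; the other keys shown are ordinary existing options and the model object is a placeholder.

```python
import torch
import deepspeed

# Hypothetical minimal config sketch using the new fields added in this commit.
ds_config = {
    "dtype": torch.float16,
    "replace_with_kernel_inject": True,
    "enable_cuda_graph": True,
    "use_triton": True,        # new: route supported ops through Triton kernels
    "triton_autotune": True,   # new: autotune Triton matmuls on the first run
}

# model = ...  # a BERT-like Huggingface model
# engine = deepspeed.init_inference(model, config=ds_config)
```
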
1 change: 1 addition & 0 deletions deepspeed/inference/engine.py
@@ -607,6 +607,7 @@ def forward(self, *inputs, **kwargs):
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)

else:
outputs = self.module(*inputs, **kwargs)

22 changes: 18 additions & 4 deletions deepspeed/model_implementations/transformers/ds_transformer.py
@@ -12,6 +12,10 @@
from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
import deepspeed
if deepspeed.HAS_TRITON:
from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP
from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention

inference_module = None

@@ -55,14 +59,24 @@ def __init__(self,

if DeepSpeedTransformerInference.layer_id == 1:
log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
if deepspeed.HAS_TRITON and self.config.use_triton:
log_dist(f"Injecting Triton kernels ...", [0])

if self.config.bigscience_bloom:
self.attention = BloomSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count)
assert not self.config.use_triton
else:
self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
merge_count)
self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
mlp_extra_grouping)
if deepspeed.HAS_TRITON and self.config.use_triton:
self.attention = TritonSelfAttention(self.config)
else:
self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
merge_count)

if deepspeed.HAS_TRITON and self.config.use_triton:
self.mlp = TritonMLP(self.config)
else:
self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
mlp_extra_grouping)

device = get_accelerator().current_device_name() # if config.bigscience_bloom else 'cpu'
if self.config.set_empty_params:
14 changes: 13 additions & 1 deletion deepspeed/module_inject/containers/base.py
@@ -8,6 +8,7 @@

import torch

import deepspeed
from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig
from deepspeed.accelerator import get_accelerator

@@ -79,6 +80,10 @@ def __init__(self, policy, config, model_config, layer_id, child):
self.input_nb = None

self.mp_group = None
self.use_triton = False

# Triton
self.use_triton = config.use_triton and deepspeed.HAS_TRITON

def create_ds_model_config(self):
self.set_hidden_heads(*self.policy.get_hidden_heads())
@@ -110,7 +115,14 @@ def create_ds_model_config(self):
use_mup=self.use_mup,
return_single_tuple=self.return_single_tuple,
set_empty_params=self.config.set_empty_params,
transposed_mode=self.config.transposed_mode)
transposed_mode=self.config.transposed_mode,
use_triton=self.use_triton,
triton_autotune=self.config.triton_autotune)

if self.use_triton and deepspeed.HAS_TRITON:
if not self.config.triton_autotune:
from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
fp16_matmul.skip_autotune()

return self.ds_model_config

1 change: 1 addition & 0 deletions deepspeed/module_inject/containers/bert.py
@@ -18,6 +18,7 @@ def __init__(self, **kwargs):
# All model specific things should be defined here instead of the base class.
self.return_tuple = True
self.triangular_masking = False
self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON

def create_module(self, config=None):
_config = config if config is not None else self.ds_model_config
1 change: 1 addition & 0 deletions deepspeed/module_inject/containers/distil_bert.py
@@ -18,6 +18,7 @@ def __init__(self, **kwargs):
# All model specific things should be defined here instead of the base class.
self.triangular_masking = False
self.return_single_tuple = True
self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON

def create_module(self, config=None):
_config = config if config is not None else self.ds_model_config
7 changes: 6 additions & 1 deletion deepspeed/ops/transformer/inference/config.py
@@ -44,6 +44,7 @@ class DeepSpeedInferenceConfig(TransformerConfig):
scale_attention: If true, both q and k are scaled by 1/sqrt(attention_heads) before attention computation.
return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor
bigscience_bloom: This flag is added temporarily for supporting the BLOOM-176B model architecture.
use_triton: This flag enables Triton kernels for inference.
"""

def __init__(self,
@@ -77,7 +78,9 @@ def __init__(self,
scale_attn_by_inverse_layer_idx=False,
return_single_tuple=False,
set_empty_params=False,
transposed_mode=False):
transposed_mode=False,
use_triton=False,
triton_autotune=False):
super(DeepSpeedInferenceConfig,
self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
num_hidden_layers)
@@ -109,6 +112,8 @@
self.return_single_tuple = return_single_tuple
self.set_empty_params = set_empty_params
self.transposed_mode = transposed_mode
self.use_triton = use_triton
self.triton_autotune = triton_autotune

@classmethod
def from_dict(cls, json_object):
9 changes: 7 additions & 2 deletions deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py
@@ -6,6 +6,7 @@
import torch
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
import deepspeed


class GELUGemmOp(BaseOp):
@@ -14,9 +15,13 @@ def __init__(self, config: DeepSpeedInferenceConfig):
super(GELUGemmOp, self).__init__(config)
try:
if self.config.dtype in [torch.float16, torch.int8]:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16 # type: ignore
if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16:
from deepspeed.ops.transformer.inference.triton.ops import fused_gemm_gelu as _triton_fused_gemm_gelu
self.fused_gemm_gelu = _triton_fused_gemm_gelu # type: ignore
else:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16 # type: ignore
else:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp32 # type: ignore
except AttributeError:
21 changes: 21 additions & 0 deletions deepspeed/ops/transformer/inference/op_binding/linear.py
@@ -6,6 +6,7 @@
import torch
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
import deepspeed


class LinearOp(BaseOp):
@@ -14,6 +15,14 @@ def __init__(self, config: DeepSpeedInferenceConfig):
super(LinearOp, self).__init__(config)
try:
if self.config.dtype in [torch.float16, torch.int8]:
if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16:
from deepspeed.ops.transformer.inference.triton.ops import linear_func as _triton_linear_func
self.linear_func = _triton_linear_func
triton_autotune = config.triton_autotune and config.layer_id == 0
if triton_autotune:
__class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size)
else:
self.linear_func = self.inference_module.linear_layer_fp16
self.linear_func = self.inference_module.linear_layer_fp16
elif self.config.dtype == torch.bfloat16:
self.linear_func = self.inference_module.linear_layer_bf16
@@ -37,3 +46,15 @@ def forward(self,
qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads,
self.config.transposed_mode)
return qkv_out

@staticmethod
def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16):
from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul
seqlen = [(min_seqlen + i)
for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
Fp16Matmul._read_autotune_table()
for N in seqlen:
A = torch.randn((N, hidden_size), dtype=dtype, device='cuda')
B = torch.randn((hidden_size, 3 * hidden_size), dtype=dtype, device='cuda')
matmul(A, B)
Fp16Matmul._update_autotune_table()
4 changes: 3 additions & 1 deletion deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py
@@ -19,7 +19,9 @@ def __init__(self, config: DeepSpeedInferenceConfig):
super(MLPGemmOp, self).__init__(config)
try:
if self.config.norm_type == NormType.LayerNorm:
if self.config.dtype in [torch.float16, torch.int8]:
if self.config.dtype in [
torch.float16, torch.int8
]: # non-triton cuda kernel has a higher performance in MLP than mlp_gemm_func in triton.ops
self.mlp_gemm_func = self.inference_module.mlp_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.mlp_gemm_func = self.inference_module.mlp_gemm_bf16