DeepSpeed-Triton for Inference #3748

Merged · 31 commits · Jun 23, 2023

Commits
6d291db
[squash] styoun/triton fp16 transformer (#530)
jeffra Jun 7, 2023
b978eae
Triton kernels and BERT inference using triton in float16 (#459)
stephen-youn Jun 14, 2023
f41e279
readme for blog
styoun Jun 16, 2023
2d499bf
typo in readme
styoun Jun 16, 2023
42b8de4
readme
styoun Jun 16, 2023
d16aaa2
readme
styoun Jun 16, 2023
b4fee89
readme
styoun Jun 16, 2023
468d698
readme
styoun Jun 16, 2023
d5fff4f
plots
styoun Jun 16, 2023
643a2fc
readme
styoun Jun 16, 2023
e6d44c9
readme
styoun Jun 16, 2023
fdb8706
readme
styoun Jun 16, 2023
b29157b
readme
styoun Jun 16, 2023
8078e04
readme
styoun Jun 16, 2023
f73bd90
typo in readme
styoun Jun 20, 2023
c2bd7dd
readme revision after the feedbacks
styoun Jun 20, 2023
19d4231
typo
styoun Jun 20, 2023
314c5ee
refined the writing in readme
styoun Jun 20, 2023
4da70a4
readme
styoun Jun 20, 2023
bad841c
readme
styoun Jun 20, 2023
47928a2
removed obsolete comments from matmul_ext.py
styoun Jun 20, 2023
db16fa0
typo
styoun Jun 21, 2023
39cbe47
Merge branch 'master' into staging-triton-bert-v1
jeffra Jun 22, 2023
219a1b8
readme change from pr comments
styoun Jun 22, 2023
072fe37
Merge branch 'staging-triton-bert-v1' of https://github.com/microsoft…
styoun Jun 22, 2023
7f7b76d
Merge branch 'master' into staging-triton-bert-v1
jeffra Jun 22, 2023
223ad1f
removed obsolete codes and comments
styoun Jun 22, 2023
afc34fa
Merge branch 'staging-triton-bert-v1' of https://github.com/microsoft…
styoun Jun 22, 2023
1eadebd
readme
styoun Jun 22, 2023
4cb8b37
Merge branch 'master' into staging-triton-bert-v1
stephen-youn Jun 22, 2023
2835e0d
Merge branch 'master' into staging-triton-bert-v1
jeffra Jun 23, 2023
2 changes: 1 addition & 1 deletion .github/workflows/formatting.yml
@@ -27,7 +27,7 @@ jobs:

- name: Install deepspeed
run: |
pip install .[dev,autotuning]
pip install .[dev,autotuning,triton]
ds_report

- name: Formatting checks
Binary file added blogs/assets/images/triton-bert-base-latency.png
Binary file added blogs/assets/images/triton-bert-large-latency.png
95 changes: 95 additions & 0 deletions blogs/deepspeed-triton/README.md
@@ -0,0 +1,95 @@
# DeepSpeed with Triton compiler

# 1. Overview

We have integrated [Triton](https://github.com/openai/triton), an open-source compiler for GPU programming, into DeepSpeed, which further boosts the inference speed of BERT-like models in float16 precision.
By replacing some CUDA kernels or torch operators with Triton kernels, we achieved a 1.14\~1.68x speedup (i.e., a 12\~41% latency reduction) across different models and GPUs, as shown in Table 1.

<div align="center">

| Hardware | Bert-base | Bert-large | Roberta-base | Roberta-large |
|----------|:------:|:------:|:------:|:------:|
| A100 |1.65x | 1.68x | 1.53x | 1.61x |
| V100 | 1.29x | 1.14x | 1.23x | 1.21x |

Table 1. The average speedup (see NOTE below for more detail)


</div>

For transformer operators in float16, we have implemented kernels written in the Triton language that replace the ordinary CUDA kernels or torch operators.
The Triton kernels we implemented include softmax, layer normalization, residual addition, and all the matrix multiplications except the MLP layers (see NOTE below for details).
In our experiments, the Triton kernels reduce the average latency (over different sequence lengths) by 6\~24%, depending on model and hardware, compared to the latency with CUDA-only kernels.
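To give a flavor of what such a replacement looks like, here is a minimal row-wise softmax kernel in the standard Triton tutorial style. This is a sketch for illustration only, not DeepSpeed's actual kernel (the fused kernels in this PR are more involved), and the helper names `softmax_kernel`/`triton_softmax` are hypothetical:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(out_ptr, in_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    # One program instance handles one row of a contiguous 2-D matrix.
    row = tl.program_id(axis=0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(in_ptr + row * n_cols + cols, mask=mask, other=-float('inf')).to(tl.float32)
    x = x - tl.max(x, axis=0)          # subtract the row max for numerical stability
    num = tl.exp(x)
    den = tl.sum(num, axis=0)
    y = (num / den).to(out_ptr.dtype.element_ty)
    tl.store(out_ptr + row * n_cols + cols, y, mask=mask)


def triton_softmax(x: torch.Tensor) -> torch.Tensor:
    # Assumes a contiguous 2-D CUDA tensor (e.g., float16 attention scores).
    n_rows, n_cols = x.shape
    out = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    softmax_kernel[(n_rows,)](out, x, n_cols, BLOCK_SIZE=BLOCK_SIZE)
    return out
```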


The figures below show the latency reduction in more detail.
Figure 1 visualizes the latency across different sequence lengths on an A100 GPU for the Bert-base model.
The baseline (blue) is Huggingface transformers without any kernel injection, the orange is DeepSpeed with CUDA-only kernels, and the gray is DeepSpeed with Triton kernels.
Figure 2 shows the same plot for the Bert-large model on an A100 GPU.

<div align="center">

<img src="../assets/images/triton-bert-base-latency.png" width="500px" alt="triton-bert-base-latency"/>

*Figure 1: Normalized P90 latency for the Bert-base model on an A100 GPU across different sequence lengths*

<img src="../assets/images/triton-bert-large-latency.png" width="500px" alt="triton-bert-large-latency"/>

*Figure 2: Normalized P90 latency for the Bert-large model on an A100 GPU across different sequence lengths*

</div>


Next, we dive deeper into this new feature in DeepSpeed.

# 2. How to use Triton in DeepSpeed

You can enable the Triton kernels by setting the `use_triton` flag in the DeepSpeed inference config, as in the example below:

```python
import torch
import deepspeed
from transformers import pipeline

pipe = pipeline('fill-mask', model='bert-base-cased', framework='pt', device=0)
pipe.model = deepspeed.init_inference(pipe.model,
                                      dtype=torch.float16,
                                      replace_with_kernel_inject=True,
                                      enable_cuda_graph=True,
                                      use_triton=True,
                                      triton_autotune=True,
                                      max_out_tokens=pipe.tokenizer.model_max_length)
```
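The same options can also be carried in a config dictionary. The following is a minimal sketch, assuming the dictionary form of `deepspeed.init_inference` and continuing the `pipe` object from the snippet above; the field names mirror the keyword arguments:

```python
# Sketch only: dictionary-style inference config with the Triton flags from this PR.
ds_config = {
    "dtype": torch.float16,
    "replace_with_kernel_inject": True,
    "enable_cuda_graph": True,
    "use_triton": True,
    "triton_autotune": True,
    "max_out_tokens": pipe.tokenizer.model_max_length,
}
pipe.model = deepspeed.init_inference(pipe.model, config=ds_config)
```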


## Running BERT inference with Triton kernels

We use Bert-base as an example here.

```bash
pip install deepspeed[triton]

git clone https://github.com/microsoft/DeepSpeedExamples.git
cd DeepSpeedExamples/inference/huggingface/fill-mask

deepspeed --num_gpus 1 test-bert.py --triton
```

To run a performance benchmark, you can use the following command:

```bash
pip install deepspeed[triton]

git clone https://github.com/microsoft/DeepSpeedExamples.git
cd DeepSpeedExamples/benchmarks/inference

deepspeed --num_gpus 1 triton-bert-benchmark.py --model bert-base-cased --dtype fp16 --kernel-inject --deepspeed --graphs --triton
```

# NOTE
<!-- **_NOTE:_** -->
* For more information on how to use DeepSpeed, please visit our [GitHub Page](https://github.com/microsoft/DeepSpeedExamples) and our [website](https://www.deepspeed.ai/), where you can find blog posts, tutorials, and documentation.

* This feature is currently supported only for BERT, Roberta, and other BERT-like models; it does not yet cover text-generation models.

* To achieve the best performance with the Triton optimization, you need to activate CUDA graph and `triton_autotune` in the DeepSpeed config. CUDA graph avoids the overhead of Triton's JIT compilation and deep call stack, while `triton_autotune` runs an initial step to find the most suitable parameters for the Triton kernels, which may take some time.

* We used [Triton 2.0.0.post1 release](https://pypi.org/project/triton/2.0.0.post1/) in our experiments.

* In our experiments, we used a batch size of 1, sequence lengths from 8 to 512, and the `fill-mask` task. Table 1 shows the average P90 latency over the entire sequence-length range, while Figures 1 and 2 show the P90 latency for specific sub-ranges. The baseline is Huggingface transformers without any optimization. The speedup is calculated as (baseline P90 latency) / (DeepSpeed-Triton P90 latency); a short sketch of this calculation follows below. We found that the CUDA kernel in the MLP performed better than the Triton kernel in our experiments, so we use a hybrid approach that combines both kernels when Triton is enabled in the DeepSpeed config.
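The speedup numbers in Table 1 reduce to a simple ratio of P90 latencies. A minimal sketch of that calculation (the latency samples below are made up for illustration, not our measured data):

```python
import numpy as np

def p90_speedup(baseline_latencies_ms, triton_latencies_ms):
    """Speedup = (baseline P90 latency) / (DeepSpeed-Triton P90 latency)."""
    return np.percentile(baseline_latencies_ms, 90) / np.percentile(triton_latencies_ms, 90)

# Illustrative latency samples only (milliseconds), not measured results.
print(p90_speedup([12.1, 12.2, 12.4, 13.0], [7.3, 7.4, 7.5, 7.9]))  # ≈ 1.65
```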
6 changes: 6 additions & 0 deletions deepspeed/__init__.py
@@ -12,6 +12,12 @@
from torch.optim.lr_scheduler import _LRScheduler
from packaging import version as pkg_version

try:
import triton # noqa: F401
HAS_TRITON = True
except ImportError:
HAS_TRITON = False

from . import ops
from . import module_inject

19 changes: 19 additions & 0 deletions deepspeed/inference/config.py
@@ -4,6 +4,7 @@
# DeepSpeed Team

import torch
import deepspeed
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from pydantic import Field
@@ -152,6 +153,18 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
can run faster using the graph replay method.
"""

use_triton: bool = False
"""
Use this flag to use triton kernels for inference ops.
"""

triton_autotune: bool = False
"""
Use this flag to enable triton autotuning.
Turning it on is better for performance but increases the first run time due to
autotuning.
"""

zero: DeepSpeedZeroConfig = {}
"""
ZeRO configuration to use with the Inference Engine. Expects a dictionary
@@ -279,6 +292,12 @@ def moe_backward_compat(cls, field_value, values):
return DeepSpeedMoEConfig(moe=field_value)
return field_value

@validator("use_triton")
def has_triton(cls, field_value, values):
if field_value and not deepspeed.HAS_TRITON:
raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
return field_value

class Config:
# Get the str representation of the datatype for serialization
json_encoders = {torch.dtype: lambda x: str(x)}
1 change: 1 addition & 0 deletions deepspeed/inference/engine.py
@@ -607,6 +607,7 @@ def forward(self, *inputs, **kwargs):
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)

else:
outputs = self.module(*inputs, **kwargs)

22 changes: 18 additions & 4 deletions deepspeed/model_implementations/transformers/ds_transformer.py
@@ -12,6 +12,10 @@
from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
import deepspeed
if deepspeed.HAS_TRITON:
from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP
from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention

inference_module = None

@@ -55,14 +59,24 @@ def __init__(self,

if DeepSpeedTransformerInference.layer_id == 1:
log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
if deepspeed.HAS_TRITON and self.config.use_triton:
log_dist(f"Injecting Triton kernels ...", [0])

if self.config.bigscience_bloom:
self.attention = BloomSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count)
assert not self.config.use_triton
else:
self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
merge_count)
self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
mlp_extra_grouping)
if deepspeed.HAS_TRITON and self.config.use_triton:
self.attention = TritonSelfAttention(self.config)
else:
self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
merge_count)

if deepspeed.HAS_TRITON and self.config.use_triton:
self.mlp = TritonMLP(self.config)
else:
self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
mlp_extra_grouping)

device = get_accelerator().current_device_name() # if config.bigscience_bloom else 'cpu'
if self.config.set_empty_params:
14 changes: 13 additions & 1 deletion deepspeed/module_inject/containers/base.py
@@ -8,6 +8,7 @@

import torch

import deepspeed
from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig
from deepspeed.accelerator import get_accelerator

@@ -79,6 +80,10 @@ def __init__(self, policy, config, model_config, layer_id, child):
self.input_nb = None

self.mp_group = None
self.use_triton = False

# Triton
self.use_triton = config.use_triton and deepspeed.HAS_TRITON

def create_ds_model_config(self):
self.set_hidden_heads(*self.policy.get_hidden_heads())
@@ -110,7 +115,14 @@ def create_ds_model_config(self):
use_mup=self.use_mup,
return_single_tuple=self.return_single_tuple,
set_empty_params=self.config.set_empty_params,
transposed_mode=self.config.transposed_mode)
transposed_mode=self.config.transposed_mode,
use_triton=self.use_triton,
triton_autotune=self.config.triton_autotune)

if self.use_triton and deepspeed.HAS_TRITON:
if not self.config.triton_autotune:
from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
fp16_matmul.skip_autotune()

return self.ds_model_config

1 change: 1 addition & 0 deletions deepspeed/module_inject/containers/bert.py
@@ -18,6 +18,7 @@ def __init__(self, **kwargs):
# All model specific things should be defined here instead of the base class.
self.return_tuple = True
self.triangular_masking = False
self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON

def create_module(self, config=None):
_config = config if config is not None else self.ds_model_config
1 change: 1 addition & 0 deletions deepspeed/module_inject/containers/distil_bert.py
@@ -18,6 +18,7 @@ def __init__(self, **kwargs):
# All model specific things should be defined here instead of the base class.
self.triangular_masking = False
self.return_single_tuple = True
self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON

def create_module(self, config=None):
_config = config if config is not None else self.ds_model_config
7 changes: 6 additions & 1 deletion deepspeed/ops/transformer/inference/config.py
@@ -44,6 +44,7 @@ class DeepSpeedInferenceConfig(TransformerConfig):
scale_attention: If true, both q and k are scaled by 1/sqrt(attention_heads) before attention computation.
return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor
bigscience_bloom: This flag is added temporarily for supporting the BLOOM-176B model architecture.
use_triton: Whether to use Triton kernels for inference ops.
"""

def __init__(self,
@@ -77,7 +78,9 @@ def __init__(self,
scale_attn_by_inverse_layer_idx=False,
return_single_tuple=False,
set_empty_params=False,
transposed_mode=False):
transposed_mode=False,
use_triton=False,
triton_autotune=False):
super(DeepSpeedInferenceConfig,
self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
num_hidden_layers)
@@ -109,6 +112,8 @@ def __init__(self,
self.return_single_tuple = return_single_tuple
self.set_empty_params = set_empty_params
self.transposed_mode = transposed_mode
self.use_triton = use_triton
self.triton_autotune = triton_autotune

@classmethod
def from_dict(cls, json_object):
9 changes: 7 additions & 2 deletions deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py
@@ -6,6 +6,7 @@
import torch
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
import deepspeed


class GELUGemmOp(BaseOp):
@@ -14,9 +15,13 @@ def __init__(self, config: DeepSpeedInferenceConfig):
super(GELUGemmOp, self).__init__(config)
try:
if self.config.dtype in [torch.float16, torch.int8]:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16 # type: ignore
if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16:
from deepspeed.ops.transformer.inference.triton.ops import fused_gemm_gelu as _triton_fused_gemm_gelu
self.fused_gemm_gelu = _triton_fused_gemm_gelu # type: ignore
else:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16 # type: ignore
else:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp32 # type: ignore
except AttributeError:
21 changes: 21 additions & 0 deletions deepspeed/ops/transformer/inference/op_binding/linear.py
@@ -6,6 +6,7 @@
import torch
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
import deepspeed


class LinearOp(BaseOp):
@@ -14,6 +15,14 @@ def __init__(self, config: DeepSpeedInferenceConfig):
super(LinearOp, self).__init__(config)
try:
if self.config.dtype in [torch.float16, torch.int8]:
if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16:
from deepspeed.ops.transformer.inference.triton.ops import linear_func as _triton_linear_func
self.linear_func = _triton_linear_func
triton_autotune = config.triton_autotune and config.layer_id == 0
if triton_autotune:
__class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size)
else:
self.linear_func = self.inference_module.linear_layer_fp16
self.linear_func = self.inference_module.linear_layer_fp16
elif self.config.dtype == torch.bfloat16:
self.linear_func = self.inference_module.linear_layer_bf16
@@ -37,3 +46,15 @@ def forward(self,
qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads,
self.config.transposed_mode)
return qkv_out

@staticmethod
def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16):
from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul
seqlen = [(min_seqlen + i)
for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
Fp16Matmul._read_autotune_table()
for N in seqlen:
A = torch.randn((N, hidden_size), dtype=dtype, device='cuda')
B = torch.randn((hidden_size, 3 * hidden_size), dtype=dtype, device='cuda')
matmul(A, B)
Fp16Matmul._update_autotune_table()
@@ -19,7 +19,9 @@ def __init__(self, config: DeepSpeedInferenceConfig):
super(MLPGemmOp, self).__init__(config)
try:
if self.config.norm_type == NormType.LayerNorm:
if self.config.dtype in [torch.float16, torch.int8]:
if self.config.dtype in [
torch.float16, torch.int8
]: # non-triton cuda kernel has a higher performance in MLP than mlp_gemm_func in triton.ops
self.mlp_gemm_func = self.inference_module.mlp_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.mlp_gemm_func = self.inference_module.mlp_gemm_bf16