Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMD] Supporting fused kernels build using JIT #1188

Merged
merged 3 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 140 additions & 4 deletions megatron/fused_kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -11,14 +14,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file has been modified from its original version
#

import os
import pathlib
import subprocess

from pathlib import Path

srcpath = Path(__file__).parent.absolute()
import torch
from torch.utils import cpp_extension

# Setting this param to a list has a problem of generating different
# compilation commands (with different order of architectures) and
Expand All @@ -28,6 +32,138 @@
os.environ["TORCH_CUDA_ARCH_LIST"] = ""


def load(neox_args=None):

# Check if cuda 11 is installed for compute capability 8.0
cc_flag = []
if torch.version.hip is None:
_, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
cpp_extension.CUDA_HOME
)
if int(bare_metal_major) >= 11:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if int(bare_metal_minor) >= 1:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_86,code=sm_86")
if int(bare_metal_minor) >= 4:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_87,code=sm_87")
if int(bare_metal_minor) >= 8:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_89,code=sm_89")
if int(bare_metal_major) >= 12:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")

# Build path
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / "build"
_create_build_dir(buildpath)

# Determine verbosity
verbose = True if neox_args is None else (neox_args.rank == 0)

# Helper function to build the kernels.
def _cpp_extention_load_helper(
name, sources, extra_cuda_flags, extra_include_paths
):
if torch.version.hip is not None:
extra_cuda_cflags = ["-O3"] + extra_cuda_flags + cc_flag
else:
extra_cuda_cflags = (
["-O3", "-gencode", "arch=compute_70,code=sm_70", "--use_fast_math"]
+ extra_cuda_flags
+ cc_flag
)

return cpp_extension.load(
name=name,
sources=sources,
build_directory=buildpath,
extra_cflags=[
"-O3",
],
extra_cuda_cflags=extra_cuda_cflags,
extra_include_paths=extra_include_paths,
verbose=verbose,
)

# ==============
# Fused softmax.
# ==============

if torch.version.hip is not None:
extra_include_paths = [os.path.abspath(srcpath)]
else:
extra_include_paths = []

if torch.version.hip is not None:
extra_cuda_flags = [
"-D__HIP_NO_HALF_OPERATORS__=1",
"-D__HIP_NO_HALF_CONVERSIONS__=1",
]
else:
extra_cuda_flags = [
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
]

# Upper triangular softmax.
sources = [
srcpath / "scaled_upper_triang_masked_softmax.cpp",
srcpath / "scaled_upper_triang_masked_softmax_cuda.cu",
]
scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_upper_triang_masked_softmax_cuda",
sources,
extra_cuda_flags,
extra_include_paths,
)
# Masked softmax.
sources = [
srcpath / "scaled_masked_softmax.cpp",
srcpath / "scaled_masked_softmax_cuda.cu",
]
scaled_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_masked_softmax_cuda", sources, extra_cuda_flags, extra_include_paths
)
# fused rope
sources = [
srcpath / "fused_rotary_positional_embedding.cpp",
srcpath / "fused_rotary_positional_embedding_cuda.cu",
]
fused_rotary_positional_embedding_cuda = _cpp_extention_load_helper(
"fused_rotary_positional_embedding_cuda",
sources,
extra_cuda_flags,
extra_include_paths,
)


def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]

return raw_output, bare_metal_major, bare_metal_minor


def _create_build_dir(buildpath):
try:
os.mkdir(buildpath)
except OSError:
if not os.path.isdir(buildpath):
print(f"Creation of the build directory {buildpath} failed")


def load_fused_kernels():
try:
import scaled_upper_triang_masked_softmax_cuda
Expand Down
2 changes: 2 additions & 0 deletions megatron/fused_kernels/scaled_masked_softmax_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
Expand Down
1 change: 1 addition & 0 deletions megatron/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def finish_mpu_init():
or neox_args.scaled_masked_softmax_fusion
or neox_args.rope_fusion
):
fused_kernels.load(neox_args)
fused_kernels.load_fused_kernels()

if neox_args.lazy_mpu_init:
Expand Down
4 changes: 4 additions & 0 deletions tests/model/test_fused_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from transformers import BertTokenizer, GPT2Tokenizer
from transformers.models.bert.modeling_bert import BertModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from megatron.fused_kernels import load
import transformers

transformers.logging.set_verbosity(
Expand All @@ -33,6 +34,7 @@
reason="ModuleNotFoundError: No module named 'scaled_masked_softmax_cuda'"
)
def test_load_fused_kernels():
load()
try:
import scaled_masked_softmax_cuda
import scaled_upper_triang_masked_softmax_cuda
Expand All @@ -47,6 +49,7 @@ def test_load_fused_kernels():

@pytest.mark.xfail(reason="SystemExit: None")
def test_fused_softmax():
load()
from megatron.model.fused_softmax import FusedScaleMaskSoftmax, SoftmaxFusionTypes
from megatron.model.gpt2_model import (
gpt2_attention_mask_func as attention_mask_func,
Expand Down Expand Up @@ -149,6 +152,7 @@ def test_fused_softmax():

@pytest.mark.xfail(reason="SystemExit: None")
def test_fused_upper_triangle_mask_softmax():
load()
from megatron.model.gpt2_model import (
gpt2_attention_mask_func as attention_mask_func,
)
Expand Down
Loading