Backport AOTriton 0.7b to ROCm/PyTorch's release/2.4 branch #1587

Merged 12 commits on Sep 13, 2024
6 changes: 3 additions & 3 deletions .ci/docker/aotriton_version.txt
@@ -1,5 +1,5 @@
-0.6b
+0.7b
manylinux_2_17
rocm6.2
-d85155583dff1da3cfe9282c9e27db390ef52f64
-e4ab195d2bd19e939c675a13280c29714c6ef9f2cf420690da150fa0cac043b1
+9be04068c3c0857a4cfd17d7e39e71d0423ebac2
+3e9e1959d23b93d78a08fcc5f868125dc3854dece32fd9458be9ef4467982291
4 changes: 2 additions & 2 deletions .ci/docker/common/install_aotriton.sh
@@ -2,12 +2,12 @@

set -ex

-TARBALL='aotriton.tar.bz2'
+TARBALL='aotriton.tar.gz'
# This read command alwasy returns with exit code 1
read -d "\n" VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256 < aotriton_version.txt || true
ARCH=$(uname -m)
AOTRITON_INSTALL_PREFIX="$1"
-AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}-shared.tar.bz2"
+AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}-shared.tar.gz"

cd "${AOTRITON_INSTALL_PREFIX}"
# Must use -L to follow redirects
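Reviewer note: the five whitespace-separated fields consumed by the `read -d "\n" VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256` line are exactly the five lines of `aotriton_version.txt` above. A minimal Python sketch of the same parsing and URL construction, handy for sanity-checking a version bump locally; the field order and URL pattern come from this diff, everything else is illustrative:

```python
# Illustrative sketch only; mirrors the logic in .ci/docker/common/install_aotriton.sh.
import platform

with open(".ci/docker/aotriton_version.txt") as f:
    ver, manylinux, rocmbase, pinned_commit, sha256 = f.read().split()

arch = platform.machine()  # e.g. "x86_64", the analogue of $(uname -m) in the script
tarball = f"aotriton-{ver}-{manylinux}_{arch}-{rocmbase}-shared.tar.gz"
url = f"https://github.com/ROCm/aotriton/releases/download/{ver}/{tarball}"

print(url)     # release asset the CI script downloads
print(sha256)  # expected SHA-256 of that tarball, per the pinned file
```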
29 changes: 22 additions & 7 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -1058,10 +1058,13 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
offset_t = at::empty({}, at::dtype(at::kLong).device(device));
} else {
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
-seed_t = at::scalar_tensor(
-at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
-offset_t = at::scalar_tensor(
-at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
+#ifdef USE_ROCM
+const auto options = at::dtype(at::kLong).device(at::kCUDA);
+#else
+const auto options = at::dtype(at::kLong);
+#endif
+seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), options);
+offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), options);
}
} else {
// Not using dropout
@@ -1074,7 +1077,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)")
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
" (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
}

// AOTriton may accept aligned on logsumexp tensor in the future for better
@@ -1103,8 +1107,16 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_

using aotriton::v2::flash::attn_fwd;
using sdp::aotriton_adapter::mk_aotensor;
+using sdp::aotriton_adapter::mk_aoscalartensor;
+using sdp::aotriton_adapter::mk_philoxtensor;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16);
at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options());
+const bool use_philox_state = in_capture_stream;
+auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
+auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
+auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
+auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
+auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
hipError_t err; // TODO: Error handling
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
@@ -1114,8 +1126,11 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
mk_aotensor<2>(softmax_lse, "M"),
mk_aotensor(output_t, "Out"),
dropout_p,
-use_dropout ? *seed_t.data_ptr<int64_t>() : 0,
-use_dropout ? *offset_t.data_ptr<int64_t>() : 0,
+seed,
+offset1,
+offset2,
+seed_output,
+offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);
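Reviewer note: the net effect of the forward changes above is that, under USE_ROCM, the dropout seed/offset are kept in device scalar tensors, and during stream capture they are read from the Philox state's device pointers via mk_philoxtensor instead of being unpacked on the host. Nothing changes in the Python call signature; a hedged sketch of the memory-efficient path that exercises this code, with shapes and dtypes chosen for illustration:

```python
# Sketch: memory-efficient SDPA with dropout on a ROCm device.
# The seed/offset handling changed above is internal; the Python-level call is unchanged.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 1024, 64])
```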
9 changes: 6 additions & 3 deletions aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -383,7 +383,8 @@ _efficient_attention_backward(
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)")
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
" (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
}
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
bool is_causal;
@@ -408,6 +409,7 @@ _efficient_attention_backward(
hipError_t err;
using aotriton::v2::flash::attn_bwd;
using sdp::aotriton_adapter::mk_aotensor;
+using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
err = attn_bwd(mk_aotensor(q_t, "q"),
@@ -424,8 +426,9 @@
mk_aotensor<2>(softmax_lse, "L"),
mk_aotensor<2>(delta, "delta"),
float(dropout_p),
-rng_engine_inputs.seed_.val,
-rng_engine_inputs.offset_.val,
+mk_aoscalartensor(philox_seed),
+mk_aoscalartensor(philox_offset),
+0,
is_causal,
stream);
#else
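Reviewer note: the backward diff makes the matching switch, passing the philox seed/offset to attn_bwd as device scalar tensors wrapped by mk_aoscalartensor instead of host integers from rng_engine_inputs. A hedged end-to-end sketch of the gradient path this feeds; only public APIs are used and the sizes are arbitrary:

```python
# Sketch: forward + backward through memory-efficient SDPA with dropout.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(2, 4, 256, 64, device="cuda", dtype=torch.float16, requires_grad=True)
           for _ in range(3))

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)
out.sum().backward()  # the saved philox seed/offset tensors are consumed by the backward kernel
print(q.grad.shape)   # torch.Size([2, 4, 256, 64])
```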
27 changes: 25 additions & 2 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -187,6 +187,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
// Check that the gpu is capable of running flash attention
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
+auto dprops = at::cuda::getCurrentDeviceProperties();
#if USE_ROCM
#if USE_AOTRITON
auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -198,11 +199,19 @@
}
return false;
}
+c10::string_view arch(dprops->gcnArchName);
+if (arch == "gfx1100") {
+static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true;
+if (!enable_navi3x) {
+TORCH_WARN_ONCE("Flash attention support on Navi31 GPU is still experimental."
+" Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1.");
+return false;
+}
+}
#else
return false;
#endif
#else
-auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm80, sm90>(dprops)) {
if (debug) {
TORCH_WARN(
@@ -222,6 +231,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
// Mem Efficient attention supports hardware in the range [sm_50, sm_90]
using sm50 = SMVersion<5, 0>;
using sm90 = SMVersion<9, 0>;
+auto dprops = at::cuda::getCurrentDeviceProperties();
#if USE_ROCM
#if USE_AOTRITON
auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -233,11 +243,19 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
}
return false;
}
+c10::string_view arch(dprops->gcnArchName);
+if (arch == "gfx1100") {
+static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true;
+if (!enable_navi3x) {
+TORCH_WARN_ONCE("Memory Efficient attention on Navi31 GPU is still experimental."
+" Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1.");
+return false;
+}
+}
#else
return false;
#endif
#else
-auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm50, sm90>(dprops)) {
if (debug) {
TORCH_WARN(
@@ -597,6 +615,11 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
}
}
}
+#if USE_ROCM
+constexpr bool backend_supports_grouped_query_attention = false;
+#else
+constexpr bool backend_supports_grouped_query_attention = true;
+#endif
if (has_only_dense_inputs(params)) {
constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_batch_size_and_num_heads_dense,
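Reviewer note: the gating added above makes gfx1100 (Navi31) strictly opt-in: on that architecture both the flash and memory-efficient checks return false with a one-time warning unless TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL is set, and grouped-query attention stays disabled for the ROCm backend. A hedged usage sketch; the variable name comes from this diff, and setting it before the first SDPA dispatch is an assumption based on the static check_env call:

```python
# Sketch: opting into the experimental Navi31 (gfx1100) flash-attention path.
# The env var is read once (static), so set it before the first SDPA call in the process.
import os
os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(1, 4, 512, 64, device="cuda", dtype=torch.float16) for _ in range(3))
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    # Raises "No available kernel" if the arch/config is still rejected.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```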
12 changes: 12 additions & 0 deletions aten/src/ATen/native/transformers/hip/aotriton_adapter.h
@@ -115,6 +115,18 @@ aotriton::TensorView<Rank> mk_aotensor(const at::Tensor& q, c10::string_view ten
cast_dtype(q.dtype()));
}

+inline aotriton::TensorView<0> mk_aoscalartensor(const at::Tensor& q)
+{
+return aotriton::TensorView<0>(reinterpret_cast<intptr_t>(q.data_ptr()),
+cast_dtype(q.dtype()));
+}
+
+inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr)
+{
+return aotriton::TensorView<0>(reinterpret_cast<intptr_t>(ptr),
+aotriton::DType::kUInt64); // AOTriton excepts unsigned int64
+}

} // namespace aotriton_adapter

} // namespace sdp
34 changes: 26 additions & 8 deletions aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip
@@ -72,7 +72,8 @@ void check_gpu_arch(hipStream_t stream) {
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)")
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
" (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
}
}

@@ -164,19 +165,23 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::Tensor seed_t, offset_t;

+at::PhiloxCudaState philox_state;
+bool use_philox_state = false;
if (p_dropout > 0.0) {
// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = batch_size * num_heads * 32;
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
-at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
+philox_state = gen->philox_cuda_state(counter_offset);
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
} else {
+// See Note [CUDA Graph-safe RNG states] about the design
+use_philox_state = true;
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
}
@@ -185,8 +190,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
} else {
-seed_t = at::empty({}, at::dtype(at::kLong));
-offset_t = at::empty({}, at::dtype(at::kLong));
+seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
}
}

@@ -219,9 +224,17 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head

hipError_t err; // TODO: Error handling
using aotriton::v2::flash::attn_fwd;
+using aotriton::TensorView;
using sdp::aotriton_adapter::mk_aotensor;
+using sdp::aotriton_adapter::mk_aoscalartensor;
+using sdp::aotriton_adapter::mk_philoxtensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
+auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
+auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
+auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
+auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
+auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
@@ -230,8 +243,11 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
mk_aotensor<2>(M, "M"),
mk_aotensor(output_t, "Out"),
p_dropout,
-philox_args.seed_.val,
-philox_args.offset_.val,
+seed,
+offset1,
+offset2,
+seed_output,
+offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);
@@ -419,6 +435,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
{
using aotriton::v2::flash::attn_bwd;
using sdp::aotriton_adapter::mk_aotensor;
+using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
err = attn_bwd(mk_aotensor(q_t, "q"),
@@ -435,8 +452,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
mk_aotensor<2>(softmax_lse_cont, "L"),
mk_aotensor<2>(delta, "delta"),
p_dropout,
-philox_args.seed_.val,
-philox_args.offset_.val,
+mk_aoscalartensor(philox_seed),
+mk_aoscalartensor(philox_offset),
+0,
is_causal,
stream);
}
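Reviewer note: the use_philox_state plumbing above is what makes dropout graph-capture-safe on ROCm: while a stream is being captured, mha_fwd reads the seed/offset from the generator's device pointers (mk_philoxtensor) and has the kernel write them back into the CUDA seed_t/offset_t tensors instead of unpacking them on the host; the removal of the @skipIfRocm marked FIXME "capturing stream has unjoined work" in test_transformers.py further down targets the same path. A hedged sketch of the pattern this unblocks, with the usual CUDA-graph warm-up; shapes and iteration counts are illustrative:

```python
# Sketch: capturing flash-attention SDPA with dropout in a HIP/CUDA graph, then replaying it.
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Warm up on a side stream before capture, as the CUDA graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)

g.replay()  # each replay draws fresh dropout randomness via the graph-safe Philox state
torch.cuda.synchronize()
print(out.shape)
```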
7 changes: 5 additions & 2 deletions test/test_native_mha.py
@@ -276,8 +276,11 @@ def do_pad_all(tensors):
@torch.no_grad()
def test_native_multihead_self_attention(self, device, dtype, use_nt,
need_weights, average_attn_weights, use_padding, pad_all, fused):
-if TEST_WITH_ROCM and use_nt:
-    self.skipTest("ROCM does not support nested tensors for Flash Attention for now.")
+if TEST_WITH_ROCM:
+    if use_nt:
+        self.skipTest("ROCM does not support nested tensors for Flash Attention for now.")
+    if use_padding and not pad_all and fused:
+        self.skipTest("Large numerical errors on ROCM to investigate.")
for need_weights in (False, not pad_all):
with self.subTest(use_padding=use_padding, pad_all=pad_all,
use_nt=use_nt, need_weights=need_weights,
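Reviewer note: the test adjusted above drives nn.MultiheadAttention's native fast path, which lowers onto the same fused SDPA kernels. A hedged reproduction of that call pattern; the hyperparameters are arbitrary, and the fast-path conditions noted in the comment come from general PyTorch documentation rather than this diff:

```python
# Sketch: the fused native MHA path exercised by test_native_multihead_self_attention.
# eval(), batch_first=True and need_weights=False are among the fast-path requirements.
import torch

mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True,
                                  device="cuda", dtype=torch.float16).eval()
x = torch.randn(2, 128, 64, device="cuda", dtype=torch.float16)

with torch.no_grad():
    out, attn_weights = mha(x, x, x, need_weights=False)
print(out.shape)     # torch.Size([2, 128, 64])
print(attn_weights)  # None when need_weights=False
```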
5 changes: 4 additions & 1 deletion test/test_transformers.py
@@ -2607,6 +2607,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
return
if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128:
torch.cuda.empty_cache() # Prevent memory fragmentation
+if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k:
+    self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k")
seed = 42
scale = scale if scale is None else (1 / head_dim)
n_heads = 4
@@ -2933,7 +2935,6 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)

-@skipIfRocm # FIXME: "capturing stream has unjoined work"
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
@parametrize("batch_size", [1, 8])
@parametrize("seq_len_q", [256, 512, 1024])
@@ -2980,6 +2981,8 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d

if fused_kernel == SDPBackend.FLASH_ATTENTION and is_causal and seq_len_q != seq_len_k:
self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k")
+if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k:
+    self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k")

seed = 42
scale = scale if scale is None else (1 / head_dim)