Backport AOTriton 0.7b to ROCm/PyTorch's release/2.4 branch (#1587)
This is a cherry-picked version of
pytorch#134498

(WIP: the unit-test (UT) part still needs some work.)
- [edit] Discussed with Xinya; the unit-test changes are already included in the
current PR.
xinyazhang authored Sep 13, 2024
1 parent c1b6f60 commit 83504b7
Showing 9 changed files with 105 additions and 28 deletions.
6 changes: 3 additions & 3 deletions .ci/docker/aotriton_version.txt
@@ -1,5 +1,5 @@
0.6b
0.7b
manylinux_2_17
rocm6.2
d85155583dff1da3cfe9282c9e27db390ef52f64
e4ab195d2bd19e939c675a13280c29714c6ef9f2cf420690da150fa0cac043b1
9be04068c3c0857a4cfd17d7e39e71d0423ebac2
3e9e1959d23b93d78a08fcc5f868125dc3854dece32fd9458be9ef4467982291
4 changes: 2 additions & 2 deletions .ci/docker/common/install_aotriton.sh
@@ -2,12 +2,12 @@

set -ex

TARBALL='aotriton.tar.bz2'
TARBALL='aotriton.tar.gz'
# This read command always returns with exit code 1
read -d "\n" VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256 < aotriton_version.txt || true
ARCH=$(uname -m)
AOTRITON_INSTALL_PREFIX="$1"
AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}-shared.tar.bz2"
AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}-shared.tar.gz"

cd "${AOTRITON_INSTALL_PREFIX}"
# Must use -L to follow redirects
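For orientation, aotriton_version.txt pins five whitespace-separated fields (release tag, manylinux tag, ROCm base, pinned AOTriton commit, and the tarball's SHA-256), and the script unpacks them with the single read shown above. Below is a minimal Python sketch of the same parsing; it is illustrative only and not part of this PR, the file path and the x86_64 arch are assumptions, and the digest check itself presumably happens further down the script, outside this hunk.

import hashlib
import pathlib

# Parse the five pinned fields, mirroring:
#   read -d "\n" VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256 < aotriton_version.txt
ver, manylinux, rocmbase, pinned_commit, sha256 = (
    pathlib.Path(".ci/docker/aotriton_version.txt").read_text().split()
)

arch = "x86_64"  # the script derives this from $(uname -m)
tarball = f"aotriton-{ver}-{manylinux}_{arch}-{rocmbase}-shared.tar.gz"
print(tarball)  # e.g. aotriton-0.7b-manylinux_2_17_x86_64-rocm6.2-shared.tar.gz

def matches_pinned_digest(path: str) -> bool:
    # Compare a downloaded tarball against the pinned SHA-256 digest.
    digest = hashlib.sha256(pathlib.Path(path).read_bytes()).hexdigest()
    return digest == sha256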
29 changes: 22 additions & 7 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -1058,10 +1058,13 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
offset_t = at::empty({}, at::dtype(at::kLong).device(device));
} else {
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
seed_t = at::scalar_tensor(
at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
offset_t = at::scalar_tensor(
at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
#ifdef USE_ROCM
const auto options = at::dtype(at::kLong).device(at::kCUDA);
#else
const auto options = at::dtype(at::kLong);
#endif
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), options);
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), options);
}
} else {
// Not using dropout
@@ -1074,7 +1077,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)")
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
" (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
}

// AOTriton may accept aligned on logsumexp tensor in the future for better
@@ -1103,8 +1107,16 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_

using aotriton::v2::flash::attn_fwd;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::mk_philoxtensor;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16);
at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options());
const bool use_philox_state = in_capture_stream;
auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
hipError_t err; // TODO: Error handling
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
@@ -1114,8 +1126,11 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
mk_aotensor<2>(softmax_lse, "M"),
mk_aotensor(output_t, "Out"),
dropout_p,
use_dropout ? *seed_t.data_ptr<int64_t>() : 0,
use_dropout ? *offset_t.data_ptr<int64_t>() : 0,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);
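The substantive change in this hunk: instead of dereferencing host-side scalars (use_dropout ? *seed_t.data_ptr<int64_t>() : 0), the kernel now receives the Philox seed and offset as device-side scalar tensors, or as raw pointers into the generator's Philox state while a CUDA graph is being captured, and seed_output/offset_output let it write the values actually used back out. This is what makes SDPA with dropout capturable in a CUDA graph on ROCm (see also the removed skipIfRocm in test_transformers.py below). A rough Python sketch of the capture pattern this enables; the shapes, dtype, and dropout_p are arbitrary, and which backend actually runs still depends on the GPU and build:

import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 256, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Warm up on a side stream before capture, per the usual CUDA Graphs recipe.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)

# Each replay reruns the captured kernels; because the seed/offset now live in
# device memory, the dropout RNG state advances correctly across replays.
graph.replay()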
9 changes: 6 additions & 3 deletions aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -383,7 +383,8 @@ _efficient_attention_backward(
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)")
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
" (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
}
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
bool is_causal;
@@ -408,6 +409,7 @@ _efficient_attention_backward(
hipError_t err;
using aotriton::v2::flash::attn_bwd;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
err = attn_bwd(mk_aotensor(q_t, "q"),
@@ -424,8 +426,9 @@ _efficient_attention_backward(
mk_aotensor<2>(softmax_lse, "L"),
mk_aotensor<2>(delta, "delta"),
float(dropout_p),
rng_engine_inputs.seed_.val,
rng_engine_inputs.offset_.val,
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
#else
27 changes: 25 additions & 2 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -187,6 +187,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
// Check that the gpu is capable of running flash attention
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
auto dprops = at::cuda::getCurrentDeviceProperties();
#if USE_ROCM
#if USE_AOTRITON
auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -198,11 +199,19 @@
}
return false;
}
c10::string_view arch(dprops->gcnArchName);
if (arch == "gfx1100") {
static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true;
if (!enable_navi3x) {
TORCH_WARN_ONCE("Flash attention support on Navi31 GPU is still experimental."
" Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1.");
return false;
}
}
#else
return false;
#endif
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm80, sm90>(dprops)) {
if (debug) {
TORCH_WARN(
@@ -222,6 +231,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
// Mem Efficient attention supports hardware in the range [sm_50, sm_90]
using sm50 = SMVersion<5, 0>;
using sm90 = SMVersion<9, 0>;
auto dprops = at::cuda::getCurrentDeviceProperties();
#if USE_ROCM
#if USE_AOTRITON
auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -233,11 +243,19 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
}
return false;
}
c10::string_view arch(dprops->gcnArchName);
if (arch == "gfx1100") {
static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true;
if (!enable_navi3x) {
TORCH_WARN_ONCE("Memory Efficient attention on Navi31 GPU is still experimental."
" Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1.");
return false;
}
}
#else
return false;
#endif
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm50, sm90>(dprops)) {
if (debug) {
TORCH_WARN(
@@ -597,6 +615,11 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
}
}
}
#if USE_ROCM
constexpr bool backend_supports_grouped_query_attention = false;
#else
constexpr bool backend_supports_grouped_query_attention = true;
#endif
if (has_only_dense_inputs(params)) {
constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_batch_size_and_num_heads_dense,
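The checks in this file now special-case gfx1100 (Navi31): support is treated as experimental and gated behind the TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL environment variable, which is read once into a static flag. A rough opt-in sketch follows; it assumes a ROCm build of PyTorch 2.4 with AOTriton and a gfx1100 device, and uses the standard torch.nn.attention API rather than anything added by this PR:

import os

# The gate is read only once (static flag), so set it before the first SDPA
# dispatch check; exporting it in the launch environment is the safest route.
os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Without the flag, a one-time warning is emitted and the flash/mem-efficient
# backends report the hardware as unsupported; with it, dispatch can proceed.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)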
12 changes: 12 additions & 0 deletions aten/src/ATen/native/transformers/hip/aotriton_adapter.h
@@ -115,6 +115,18 @@ aotriton::TensorView<Rank> mk_aotensor(const at::Tensor& q, c10::string_view ten
cast_dtype(q.dtype()));
}

inline aotriton::TensorView<0> mk_aoscalartensor(const at::Tensor& q)
{
return aotriton::TensorView<0>(reinterpret_cast<intptr_t>(q.data_ptr()),
cast_dtype(q.dtype()));
}

inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr)
{
return aotriton::TensorView<0>(reinterpret_cast<intptr_t>(ptr),
aotriton::DType::kUInt64); // AOTriton expects unsigned int64
}

} // namespace aotriton_adapter

} // namespace sdp
34 changes: 26 additions & 8 deletions aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip
@@ -72,7 +72,8 @@ void check_gpu_arch(hipStream_t stream) {
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)")
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
" (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
}
}

@@ -164,19 +165,23 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::Tensor seed_t, offset_t;

at::PhiloxCudaState philox_state;
bool use_philox_state = false;
if (p_dropout > 0.0) {
// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = batch_size * num_heads * 32;
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
philox_state = gen->philox_cuda_state(counter_offset);
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
} else {
// See Note [CUDA Graph-safe RNG states] about the design
use_philox_state = true;
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
}
@@ -185,8 +190,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
} else {
seed_t = at::empty({}, at::dtype(at::kLong));
offset_t = at::empty({}, at::dtype(at::kLong));
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
}
}

@@ -219,9 +224,17 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head

hipError_t err; // TODO: Error handling
using aotriton::v2::flash::attn_fwd;
using aotriton::TensorView;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::mk_philoxtensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
@@ -230,8 +243,11 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
mk_aotensor<2>(M, "M"),
mk_aotensor(output_t, "Out"),
p_dropout,
philox_args.seed_.val,
philox_args.offset_.val,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);
@@ -419,6 +435,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
{
using aotriton::v2::flash::attn_bwd;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
err = attn_bwd(mk_aotensor(q_t, "q"),
@@ -435,8 +452,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
mk_aotensor<2>(softmax_lse_cont, "L"),
mk_aotensor<2>(delta, "delta"),
p_dropout,
philox_args.seed_.val,
philox_args.offset_.val,
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
}
7 changes: 5 additions & 2 deletions test/test_native_mha.py
@@ -276,8 +276,11 @@ def do_pad_all(tensors):
@torch.no_grad()
def test_native_multihead_self_attention(self, device, dtype, use_nt,
need_weights, average_attn_weights, use_padding, pad_all, fused):
if TEST_WITH_ROCM and use_nt:
self.skipTest("ROCM does not support nested tensors for Flash Attention for now.")
if TEST_WITH_ROCM:
if use_nt:
self.skipTest("ROCM does not support nested tensors for Flash Attention for now.")
if use_padding and not pad_all and fused:
self.skipTest("Large numerical errors on ROCM to investigate.")
for need_weights in (False, not pad_all):
with self.subTest(use_padding=use_padding, pad_all=pad_all,
use_nt=use_nt, need_weights=need_weights,
5 changes: 4 additions & 1 deletion test/test_transformers.py
@@ -2607,6 +2607,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
return
if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128:
torch.cuda.empty_cache() # Prevent memory fragmentation
if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k:
self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k")
seed = 42
scale = scale if scale is None else (1 / head_dim)
n_heads = 4
@@ -2933,7 +2935,6 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)

@skipIfRocm # FIXME: "capturing stream has unjoined work"
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
@parametrize("batch_size", [1, 8])
@parametrize("seq_len_q", [256, 512, 1024])
@@ -2980,6 +2981,8 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d

if fused_kernel == SDPBackend.FLASH_ATTENTION and is_causal and seq_len_q != seq_len_k:
self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k")
if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k:
self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k")

seed = 42
scale = scale if scale is None else (1 / head_dim)
