PR #12255: [GPU] Fix FMHA hangs by moving compilation to thunk initialization.

Imported from GitHub PR #12255

Applies the same fix as in #12228 to FMHA.
Copybara import of the project:

--
70a4282 by Ilia Sergachev <[email protected]>:

[GPU] Fix FMHA hangs by moving compilation to thunk initialization.

Merging this change closes #12255

COPYBARA_INTEGRATE_REVIEW=#12255 from openxla:fix_fmha_hang 70a4282
PiperOrigin-RevId: 633217196
sergachev authored and copybara-github committed May 13, 2024
1 parent 9f9d7f6 commit c7a2865
Showing 4 changed files with 70 additions and 43 deletions.
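The change follows the same pattern as #12228: the cuDNN runner for the fused-MHA op is built eagerly in the thunk's Initialize, so the compilation that previously could hang execution happens before ExecuteOnStream ever runs. Below is a minimal sketch of that pattern only, using stand-in names (ExampleThunk, FakeRunner, FakeStream) that are not the actual XLA/StreamExecutor classes; the real change wires FusedMHAThunk::Initialize to se::dnn::LazyOpRunner::GetOrCreateRunner as shown in the diff.

// Hedged sketch, not the real implementation: illustrates moving expensive
// compilation from the execute path into a one-time Initialize step.
#include <memory>

#include "absl/status/status.h"

struct FakeStream {};  // stand-in for se::Stream

class FakeRunner {
 public:
  absl::Status Compile(FakeStream* /*stream*/) {
    // Pretend this is the expensive, potentially blocking cuDNN graph build.
    compiled_ = true;
    return absl::OkStatus();
  }
  bool compiled() const { return compiled_; }

 private:
  bool compiled_ = false;
};

class ExampleThunk {
 public:
  // Runs once before execution begins: the right place for expensive work
  // such as kernel or cuDNN graph compilation.
  absl::Status Initialize(FakeStream* stream) {
    if (runner_ == nullptr) {
      runner_ = std::make_unique<FakeRunner>();
      return runner_->Compile(stream);
    }
    return absl::OkStatus();
  }

  // Runs on every execution; compilation already happened in Initialize, so
  // this path no longer stalls waiting for it.
  absl::Status ExecuteOnStream(FakeStream* /*stream*/) {
    if (runner_ == nullptr || !runner_->compiled()) {
      return absl::InternalError("runner was not initialized");
    }
    return absl::OkStatus();  // launch the precompiled runner here
  }

 private:
  std::unique_ptr<FakeRunner> runner_;
};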
89 changes: 46 additions & 43 deletions xla/service/gpu/gpu_fused_mha_runner.cc
@@ -68,29 +68,13 @@ absl::Status RunFusedMHA(GpufMHAParams params, se::Stream *stream,
dropout_rate = *params.config->dropout_rate;
}

double scale = 1.0;
if (params.config->fmha_scale) {
scale = *params.config->fmha_scale;
}

std::optional<int64_t> seed;
if (params.config->seed) {
seed = *params.config->seed;
}
TF_ASSIGN_OR_RETURN(
se::dnn::FMHAMaskKind mask_type,
GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(params.config->mask_type));
se::dnn::FusedMHAOp::Config config{scale,
params.config->lhs_bmm1,
params.config->rhs_bmm1,
params.config->rhs_bmm2,
params.config->intermediate_lhs_bmm2,
params.config->output,
params.config->bias,
params.config->activation,
dropout_rate,
seed,
mask_type};

TF_ASSIGN_OR_RETURN(se::dnn::FusedMHAOp::Config config,
params.config->AsDnnFusedMHAOpConfig());
TF_ASSIGN_OR_RETURN(auto *runner,
lazy_runner->GetOrCreateRunner(config, stream));
return (*runner)(stream, options.profile_result, scratch_memory,
@@ -183,35 +167,13 @@ absl::Status RunFusedMHABackward(
dropout_rate = *params.config->dropout_rate;
}

double scale = 1.0;
if (params.config->fmha_scale) {
scale = *params.config->fmha_scale;
}

std::optional<int64_t> seed;
if (params.config->seed) {
seed = *params.config->seed;
}

TF_ASSIGN_OR_RETURN(
se::dnn::FMHAMaskKind mask_type,
GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(params.config->mask_type));
se::dnn::FusedMHABackwardOp::Config config{scale,
params.config->bmm1_grad_gemm1_rhs,
params.config->bmm1_grad_gemm2_rhs,
params.config->bmm2_grad_gemm1_lhs,
params.config->bmm2_grad_gemm2_rhs,
params.config->d_output,
params.config->d_bmm1_lhs,
params.config->d_bmm1_rhs,
params.config->d_bmm2_rhs,
params.config->d_s,
params.config->d_bias,
params.config->fwd_output,
params.config->bias,
dropout_rate,
seed,
mask_type};
TF_ASSIGN_OR_RETURN(se::dnn::FusedMHABackwardOp::Config config,
params.config->AsDnnFusedMHABackwardOpConfig());
TF_ASSIGN_OR_RETURN(auto *runner,
lazy_runner->GetOrCreateRunner(config, stream));
// TODO: pass in real softmax_sum, dQ_accum, fwd_output
@@ -404,6 +366,21 @@ absl::Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams &params,
return config;
}

absl::StatusOr<se::dnn::FusedMHAOp::Config>
GpufMHAConfig::AsDnnFusedMHAOpConfig() const {
double scale = 1.0;
if (fmha_scale.has_value()) {
scale = *fmha_scale;
}
TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type,
GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type));

return se::dnn::FusedMHAOp::Config{
scale, lhs_bmm1, rhs_bmm1, rhs_bmm2, intermediate_lhs_bmm2,
output, bias, activation, dropout_rate, seed,
mask_type};
}

/*static*/ absl::StatusOr<GpufMHABackwardConfig> GpufMHABackwardConfig::For(
const GpufMHABackwardDescriptor &desc) {
// Get shapes from desc.
@@ -546,6 +523,32 @@ absl::Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams &params,
return config;
}

absl::StatusOr<se::dnn::FusedMHABackwardOp::Config>
GpufMHABackwardConfig::AsDnnFusedMHABackwardOpConfig() const {
double scale = 1.0;
if (fmha_scale.has_value()) {
scale = *fmha_scale;
}
TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type,
GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type));
return se::dnn::FusedMHABackwardOp::Config{scale,
bmm1_grad_gemm1_rhs,
bmm1_grad_gemm2_rhs,
bmm2_grad_gemm1_lhs,
bmm2_grad_gemm2_rhs,
d_output,
d_bmm1_lhs,
d_bmm1_rhs,
d_bmm2_rhs,
d_s,
d_bias,
fwd_output,
bias,
dropout_rate,
seed,
mask_type};
}

/*static*/ absl::StatusOr<GpufMHAParams> GpufMHAParams::For(
const GpufMHAConfig &config, se::DeviceMemoryBase lhs_bmm1_buffer,
se::DeviceMemoryBase rhs_bmm1_buffer, se::DeviceMemoryBase rhs_bmm2_buffer,
8 changes: 8 additions & 0 deletions xla/service/gpu/gpu_fused_mha_runner.h
@@ -101,10 +101,14 @@ struct GpufMHABackwardDescriptor {
std::optional<Shape> d_bias_shape;
std::optional<Shape> bias_shape;
};

// Structure to describe static properties of a GPU fused Multi-Headed
// Attention.
struct GpufMHAConfig {
static absl::StatusOr<GpufMHAConfig> For(const GpufMHADescriptor& fmha_desc);

absl::StatusOr<se::dnn::FusedMHAOp::Config> AsDnnFusedMHAOpConfig() const;

PrimitiveType
input_type; // Capture the primitive type of one of the inputs of BMM1
PrimitiveType output_type;
@@ -133,6 +137,10 @@ struct GpufMHAConfig {
struct GpufMHABackwardConfig {
static absl::StatusOr<GpufMHABackwardConfig> For(
const GpufMHABackwardDescriptor& fmha_desc);

absl::StatusOr<se::dnn::FusedMHABackwardOp::Config>
AsDnnFusedMHABackwardOpConfig() const;

PrimitiveType
input_type; // Capture the primitive type of one of the inputs of BMM1
PrimitiveType output_type;
14 changes: 14 additions & 0 deletions xla/service/gpu/runtime/fused_mha_thunk.cc
@@ -65,6 +65,13 @@ std::optional<se::DeviceMemoryBase> AssignBufferIfNotNull(
: std::nullopt;
}

absl::Status FusedMHAThunk::Initialize(const InitializeParams& params) {
se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>* lazy_runner =
GetOrCreateRunner(params.stream).AsFusedMHARunner();
TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHAOpConfig());
return lazy_runner->GetOrCreateRunner(config, params.stream).status();
}

absl::Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) {
const auto& buffer_allocations = *params.buffer_allocations;
se::DeviceMemoryBase lhs_bmm1_buffer =
@@ -143,6 +150,13 @@ FusedMHABackwardThunk::GetOrCreateRunner(
return *it->second;
}

absl::Status FusedMHABackwardThunk::Initialize(const InitializeParams& params) {
se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>* lazy_runner =
GetOrCreateRunner(params.stream).AsFusedMHABackwardRunner();
TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHABackwardOpConfig());
return lazy_runner->GetOrCreateRunner(config, params.stream).status();
}

absl::Status FusedMHABackwardThunk::ExecuteOnStream(
const ExecuteParams& params) {
const auto& buffer_allocations = *params.buffer_allocations;
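A hedged note on why the eager call in Initialize is sufficient (inferred from the surrounding code, not stated explicitly in the change): the thunk's GetOrCreateRunner caches one lazy runner per stream, and LazyOpRunner::GetOrCreateRunner builds the underlying cuDNN runner only on its first call for a given config. The call in Initialize therefore pays the compilation cost up front, and the identical call later in ExecuteOnStream returns the already-built runner. Factoring the config construction into AsDnnFusedMHAOpConfig and AsDnnFusedMHABackwardOpConfig keeps the initialization and execution paths building the same config.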
2 changes: 2 additions & 0 deletions xla/service/gpu/runtime/fused_mha_thunk.h
@@ -53,6 +53,7 @@ class FusedMHAThunk : public Thunk {
FusedMHAThunk(const FusedMHAThunk&) = delete;
FusedMHAThunk& operator=(const FusedMHAThunk&) = delete;

absl::Status Initialize(const InitializeParams& params) override;
absl::Status ExecuteOnStream(const ExecuteParams& params) override;

private:
@@ -101,6 +102,7 @@ class FusedMHABackwardThunk : public Thunk {
FusedMHABackwardThunk(const FusedMHABackwardThunk&) = delete;
FusedMHABackwardThunk& operator=(const FusedMHABackwardThunk&) = delete;

absl::Status Initialize(const InitializeParams& params) override;
absl::Status ExecuteOnStream(const ExecuteParams& params) override;

private:
