Add TuningContext
cloudhan committed Feb 3, 2023
1 parent 1ae9fe6 commit fedda8f
Showing 52 changed files with 644 additions and 248 deletions.
8 changes: 8 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
@@ -30,6 +30,7 @@ class Node;
#include "core/framework/func_api.h"
#include "core/framework/provider_options.h"
#include "core/framework/stream_handles.h"
#include "core/framework/tuning_context.h"

namespace onnxruntime {

@@ -300,6 +301,13 @@ class IExecutionProvider {
*/
virtual bool ConcurrentRunSupported() const { return true; }

/**
* Return the tuning context which holds all TunableOp state.
*/
virtual ITuningContext* GetTuningContext() const {
return nullptr;
}

private:
const std::string type_;

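
The hook defaults to nullptr, so providers without tunable ops need no change. Below is a minimal sketch of how a provider might override it; the class name, constructor, and tuning_context_ member are illustrative assumptions, not code from this commit.

#include <memory>
#include "core/framework/execution_provider.h"

// Sketch only: a provider exposing a provider-owned context through the new hook.
class MyExecutionProvider : public onnxruntime::IExecutionProvider {
 public:
  MyExecutionProvider() : IExecutionProvider("MyExecutionProvider") {}

  onnxruntime::ITuningContext* GetTuningContext() const override {
    // The context object owns all TunableOp state for this provider instance.
    return tuning_context_.get();
  }

 private:
  std::unique_ptr<onnxruntime::ITuningContext> tuning_context_;  // created by the provider as needed
};
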
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc
@@ -363,7 +363,7 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
return LaunchDecoderAttentionKernel(
device_prop,
#ifdef USE_ROCM
IsTunableOpEnabled(),
GetTuningContext(),
#endif
stream,
cublas,
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/rocm/bert/attention.cc
@@ -88,7 +88,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
// TODO: use custom kernel of expand to improve the performance.
ORT_RETURN_IF_ERROR(blas::column_major::Gemm(
IsTunableOpEnabled(), Stream(context), rocblas,
GetTuningContext(), Stream(context), rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
n, m, 1,
/*alpha=*/1.0f,
@@ -99,7 +99,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

// result(N, M) = 1 * weights x input + 1 x B.
ORT_RETURN_IF_ERROR(blas::column_major::Gemm(
IsTunableOpEnabled(), Stream(context), rocblas,
GetTuningContext(), Stream(context), rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
n, m, k,
/*alpha=*/1.0f,
@@ -114,7 +114,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
return LaunchAttentionKernel(
device_prop,
IsTunableOpEnabled(),
GetTuningContext(),
Stream(context),
rocblas,
element_size,
28 changes: 14 additions & 14 deletions onnxruntime/contrib_ops/rocm/bert/attention_impl.cu
@@ -72,7 +72,7 @@ size_t GetAttentionWorkspaceSize(
template <typename T>
Status QkvToContext(
const hipDeviceProp_t& prop,
bool tuning,
RocmTuningContext* tuning_ctx,
rocblas_handle& rocblas,
hipStream_t stream,
const int batch_size,
@@ -139,7 +139,7 @@ Status QkvToContext(
const int temp_matrix_size = sequence_length * all_sequence_length;

ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning, stream, rocblas,
tuning_ctx, stream, rocblas,
blas::BlasOp::Trans, blas::BlasOp::NonTrans,
all_sequence_length, sequence_length, head_size,
// For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation.
@@ -174,7 +174,7 @@ Status QkvToContext(

// compute P*V (as V*P), and store in scratch3: BxNxSxH
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning, stream, rocblas,
tuning_ctx, stream, rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
head_size, sequence_length, all_sequence_length,
/*alpha=*/1.0f,
@@ -191,7 +191,7 @@ Status QkvToContext(

Status LaunchAttentionKernel(
const hipDeviceProp_t& prop,
bool tuning,
RocmTuningContext* tuning_ctx,
hipStream_t stream,
rocblas_handle& rocblas,
const size_t element_size,
@@ -215,7 +215,7 @@ Status LaunchAttentionKernel(
bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax();
if (element_size == 2) {
return QkvToContext(
prop, tuning, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
prop, tuning_ctx, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
reinterpret_cast<const __half*>(input),
reinterpret_cast<__half*>(output),
reinterpret_cast<__half*>(workspace),
@@ -230,7 +230,7 @@ Status LaunchAttentionKernel(
use_persistent_softmax);
} else {
return QkvToContext(
prop, tuning, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
prop, tuning_ctx, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
reinterpret_cast<const float*>(input),
reinterpret_cast<float*>(output),
reinterpret_cast<float*>(workspace),
@@ -249,7 +249,7 @@ Status LaunchAttentionKernel(
template <typename T>
Status DecoderQkvToContext(
const hipDeviceProp_t& prop,
bool tuning,
RocmTuningContext* tuning_ctx,
hipStream_t stream,
rocblas_handle& rocblas,
const size_t element_size,
@@ -352,7 +352,7 @@ Status DecoderQkvToContext(
const int strideB = sequence_length * head_size;
if (use_past && static_kv) {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning, stream, rocblas,
tuning_ctx, stream, rocblas,
blas::BlasOp::Trans, blas::BlasOp::NonTrans,
kv_sequence_length, sequence_length, head_size,
/*alpha=*/rsqrt_head_size,
@@ -363,7 +363,7 @@
BN));
} else {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning, stream, rocblas,
tuning_ctx, stream, rocblas,
blas::BlasOp::Trans, blas::BlasOp::NonTrans,
kv_sequence_length, sequence_length, head_size,
/*alpha=*/rsqrt_head_size,
@@ -386,7 +386,7 @@
// compute P*V (as V*P), and store in scratch3: BxNxSxH
if (use_past && static_kv) {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning, stream, rocblas,
tuning_ctx, stream, rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
head_size, sequence_length, kv_sequence_length,
/*alpha=*/1.0f,
@@ -397,7 +397,7 @@
BN));
} else {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning, stream, rocblas,
tuning_ctx, stream, rocblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
head_size, sequence_length, kv_sequence_length,
/*alpha=*/1.0f,
@@ -415,7 +415,7 @@

Status LaunchDecoderAttentionKernel(
const hipDeviceProp_t& prop,
bool tuning,
RocmTuningContext* tuning_ctx,
hipStream_t stream,
rocblas_handle& rocblas,
const size_t element_size,
@@ -442,7 +442,7 @@
if (element_size == 2) {
return DecoderQkvToContext(
prop,
tuning,
tuning_ctx,
stream,
rocblas,
element_size,
@@ -469,7 +469,7 @@
} else {
return DecoderQkvToContext(
prop,
tuning,
tuning_ctx,
stream,
rocblas,
element_size,
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/rocm/bert/attention_impl.h
@@ -6,6 +6,7 @@
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"

namespace onnxruntime {
namespace contrib {
@@ -27,7 +28,7 @@ size_t GetAttentionWorkspaceSize(

Status LaunchAttentionKernel(
const hipDeviceProp_t& prop, // Device Properties
bool tuning, // Whether to enable tuning
RocmTuningContext* tuning_ctx, // context for tuning
hipStream_t stream, // Hip stream
rocblas_handle& rocblas, // Rocblas handle
const size_t element_size, // Element size of input tensor
@@ -50,7 +51,7 @@ Status LaunchAttentionKernel(

Status LaunchDecoderAttentionKernel(
const hipDeviceProp_t& prop, // Device Properties
bool tuning, // Whether to enable tuning
RocmTuningContext* tuning_ctx, // context for tuning
hipStream_t stream, // Hip stream
rocblas_handle& rocblas, // Rocblas handle
const size_t element_size, // Element size of input tensor
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc
@@ -48,7 +48,7 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
typedef typename ToHipType<T>::MappedType HipT;

return LaunchFastGeluKernel<HipT>(IsTunableOpEnabled(),
return LaunchFastGeluKernel<HipT>(GetTuningContext(),
Stream(context),
static_cast<int>(input_length),
static_cast<int>(bias_length),
16 changes: 9 additions & 7 deletions onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu
@@ -40,25 +40,27 @@ namespace contrib {
namespace rocm {

template <typename T>
Status LaunchFastGeluKernel(bool tuning, hipStream_t stream, int input_length, int bias_length,
Status LaunchFastGeluKernel(RocmTuningContext* tuning_ctx, hipStream_t stream, int input_length, int bias_length,
const T* input, const T* bias, T* output) {
FastGeluParams<T> params(stream, input, bias, output, input_length, bias_length);
if (tuning) {
FastGeluParams<T> params(tuning_ctx, stream, input, bias, output, input_length, bias_length);
if (tuning_ctx->IsTunableOpEnabled()) {
static FastGeluTunableOp<T> op;
op.EnableTuning();
return op(&params);
}

return FastGeluStaticSelection<T>(&params);
}

template Status LaunchFastGeluKernel<float>(bool tuning, hipStream_t stream, int input_length, int bias_length,
template Status LaunchFastGeluKernel<float>(RocmTuningContext* tuning_ctx, hipStream_t stream,
int input_length, int bias_length,
const float* input, const float* bias, float* output);

template Status LaunchFastGeluKernel<BFloat16>(bool tuning, hipStream_t stream, int input_length, int bias_length,
template Status LaunchFastGeluKernel<BFloat16>(RocmTuningContext* tuning_ctx, hipStream_t stream,
int input_length, int bias_length,
const BFloat16* input, const BFloat16* bias, BFloat16* output);

template Status LaunchFastGeluKernel<half>(bool tuning, hipStream_t stream, int input_length, int bias_length,
template Status LaunchFastGeluKernel<half>(RocmTuningContext* tuning_ctx, hipStream_t stream,
int input_length, int bias_length,
const half* input, const half* bias, half* output);

} // namespace rocm
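
As the hunk above shows, the launcher now receives a RocmTuningContext*, threads it into the op params, and asks the context whether tuning is enabled, instead of taking a bool and calling EnableTuning() on the op. A hypothetical launcher following the same shape; MyParams, MyTunableOp, and MyDefaultKernel are made-up names, and the null check is a defensive sketch choice (the call sites in this commit assume a non-null context):

// Sketch of the calling pattern introduced above; all "My*" names are hypothetical.
template <typename T>
Status LaunchMyKernel(RocmTuningContext* tuning_ctx, hipStream_t stream,
                      const T* input, T* output, int length) {
  MyParams<T> params(tuning_ctx, stream, input, output, length);
  if (tuning_ctx != nullptr && tuning_ctx->IsTunableOpEnabled()) {
    static MyTunableOp<T> op;  // static so tuning results are cached across calls
    return op(&params);
  }
  return MyDefaultKernel<T>(&params);  // untuned fallback path
}
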
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.h
@@ -6,14 +6,16 @@
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

template <typename T>
Status LaunchFastGeluKernel(bool tuning, hipStream_t stream, int input_length, int bias_length,
Status LaunchFastGeluKernel(RocmTuningContext* tuning_ctx, hipStream_t stream, int input_length, int bias_length,
const T* input, const T* bias, T* output);

} // namespace rocm
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h
@@ -20,8 +20,8 @@ namespace rocm {

template <typename T>
struct FastGeluParams : OpParams {
FastGeluParams(hipStream_t stream, const T* input, const T* bias, T* output, int input_length, int bias_length) :
OpParams(stream), input(input), bias(bias), output(output), input_length(input_length), bias_length(bias_length) {}
FastGeluParams(RocmTuningContext* tuning_ctx, hipStream_t stream, const T* input, const T* bias, T* output, int input_length, int bias_length) :
OpParams(tuning_ctx, stream), input(input), bias(bias), output(output), input_length(input_length), bias_length(bias_length) {}

std::string Signature() const override {
std::string sig = std::to_string(input_length) + "_" + std::to_string(bias_length);
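
The constructor change above implies that OpParams now carries the tuning context alongside the stream, so a tunable op (including nested calls) can reach it through its params. A hypothetical params type following the new convention, for illustration only:

// Sketch: params type using the new OpParams(tuning_ctx, stream) base constructor.
// "MyOpParams" and its fields are illustrative, not part of this commit.
template <typename T>
struct MyOpParams : OpParams {
  MyOpParams(RocmTuningContext* tuning_ctx, hipStream_t stream, const T* x, T* y, int n)
      : OpParams(tuning_ctx, stream), x(x), y(y), n(n) {}

  std::string Signature() const override {
    return std::to_string(n);  // keys the tuning cache by problem size
  }

  const T* x;
  T* y;
  int n;
};
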
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc
@@ -58,7 +58,7 @@ Status GemmFastGelu<T>::ComputeInternal(OpKernelContext* ctx) const {
using onnxruntime::rocm::tunable::blas::BlasOp;

return blas::row_major::GemmFastGelu(
IsTunableOpEnabled(),
GetTuningContext(),
Stream(ctx), GetRocblasHandle(ctx),
transa ? BlasOp::Trans : BlasOp::NonTrans,
transb ? BlasOp::Trans : BlasOp::NonTrans,
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h
@@ -39,7 +39,6 @@ struct GemmFastGeluParams : OpParams {
T beta;
T* c;
int64_t ldc;
bool tuning{false};
};

} // namespace blas
10 changes: 3 additions & 7 deletions onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu
@@ -20,6 +20,7 @@ namespace row_major {
template <typename T, typename ScalarT>
inline GEMMFASTGELU(T, ScalarT) {
GemmFastGeluParams<T> params;
params.tuning_ctx = tuning_ctx;
params.stream = stream;
params.handle = handle;

@@ -46,23 +47,18 @@ inline GEMMFASTGELU(T, ScalarT) {
params.c = c;
params.ldc = ldc;

if (tunable) {
params.tuning = true;
if (tuning_ctx->IsTunableOpEnabled()) {
if (opa == BlasOp::N && opb == BlasOp::N) {
static internal::GemmFastGeluTunableOp<T, internal::Row, internal::Row> gemm_fast_gelu{};
gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
} else if (opa == BlasOp::T && opb == BlasOp::N) {
static internal::GemmFastGeluTunableOp<T, internal::Col, internal::Row> gemm_fast_gelu{};
gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
} else if (opa == BlasOp::N && opb == BlasOp::T) {
static internal::GemmFastGeluTunableOp<T, internal::Row, internal::Col> gemm_fast_gelu{};
gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
} else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ {
static internal::GemmFastGeluTunableOp<T, internal::Col, internal::Col> gemm_fast_gelu{};
gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
}
}
@@ -71,7 +67,7 @@ inline GEMMFASTGELU(T, ScalarT) {
}

#define CALL_GEMMFASTGELU(T, ScalarT) \
GemmFastGelu<T, ScalarT>(tunable, stream, handle, \
GemmFastGelu<T, ScalarT>(tuning_ctx, stream, handle, \
opa, opb, \
m, n, k, \
alpha, a, lda, b, ldb, bias, \
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h
@@ -14,7 +14,7 @@ namespace blas {

#define GEMMFASTGELU(T, ScalarT) \
common::Status GemmFastGelu( \
bool tunable, hipStream_t stream, rocblas_handle handle, \
RocmTuningContext* tuning_ctx, hipStream_t stream, rocblas_handle handle, \
BlasOp opa, BlasOp opb, \
std::int64_t m, std::int64_t n, std::int64_t k, \
ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh
@@ -21,7 +21,7 @@ namespace internal {
template <typename T>
Status GemmFastGeluUnfused(const GemmFastGeluParams<T>* params) {
namespace column_major = onnxruntime::rocm::tunable::blas::column_major;
ORT_RETURN_IF_ERROR(column_major::Gemm(params->tuning, params->stream, params->handle,
ORT_RETURN_IF_ERROR(column_major::Gemm(params->tuning_ctx, params->stream, params->handle,
params->opb, params->opa,
params->n, params->m, params->k,
params->alpha, params->b, params->ldb, params->a, params->lda,
@@ -41,7 +41,7 @@ Status GemmFastGeluUnfused(const GemmFastGeluParams<T>* params) {
//
// Note: If any change cause directly usage of GemmFastGeluUnfused, add PreTuning() and PostTuning() in FastGeluTunableOp
// to protect original input value.
return onnxruntime::contrib::rocm::LaunchFastGeluKernel<T>(params->tuning,
return onnxruntime::contrib::rocm::LaunchFastGeluKernel<T>(params->tuning_ctx,
params->stream,
static_cast<int>(fast_gelu_input_length),
static_cast<int>(bias_length),
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc
@@ -98,7 +98,7 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
typedef typename ToHipType<T>::MappedType HipT;

return LaunchSkipLayerNormKernel<HipT>(
IsTunableOpEnabled(),
GetTuningContext(),
Stream(ctx),
reinterpret_cast<HipT*>(output->MutableData<T>()),
reinterpret_cast<const HipT*>(input->Data<T>()),
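
Across the files shown, the kernel call sites make the same substitution: ComputeInternal passes GetTuningContext() where it previously passed IsTunableOpEnabled(). A condensed, hypothetical call site in the style of the FastGelu and SkipLayerNorm changes above; MyOp and LaunchMyKernel are made-up names:

// Sketch of the call-site pattern seen in this commit; "MyOp" and "LaunchMyKernel"
// are hypothetical stand-ins for the ROCm contrib kernels above.
template <typename T>
Status MyOp<T>::ComputeInternal(OpKernelContext* ctx) const {
  typedef typename ToHipType<T>::MappedType HipT;
  const Tensor* input = ctx->Input<Tensor>(0);
  Tensor* output = ctx->Output(0, input->Shape());
  return LaunchMyKernel<HipT>(
      GetTuningContext(),  // previously: IsTunableOpEnabled()
      Stream(ctx),
      reinterpret_cast<HipT*>(output->MutableData<T>()),
      reinterpret_cast<const HipT*>(input->Data<T>()),
      static_cast<int>(input->Shape().Size()));
}
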