Change the return type of softmax function to Status #14559

Merged · 5 commits · Feb 6, 2023
12 changes: 6 additions & 6 deletions onnxruntime/contrib_ops/cuda/bert/attention_softmax.h
@@ -714,12 +714,12 @@ Status ComputeSoftmaxWithRawMask(cudaStream_t stream,
   }
 
   if (use_persistent_softmax) {
-    dispatch_warpwise_softmax_forward<T, T, float, false>(stream,
-        output,
-        persistent_softmax_workspace,
-        all_sequence_length,
-        all_sequence_length,
-        batch_size * num_heads * sequence_length);
+    return dispatch_warpwise_softmax_forward<T, T, float, false>(stream,
+        output,
+        persistent_softmax_workspace,
+        all_sequence_length,
+        all_sequence_length,
+        batch_size * num_heads * sequence_length);
   }
 
   return CUDA_CALL(cudaGetLastError());
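This first hunk shows the pattern the whole PR applies: a kernel launcher that used to return void now returns a Status derived from the device error state, and its callers either return that Status directly (as here) or wrap the call in ORT_RETURN_IF_ERROR (as in the hunks below). What follows is a minimal, self-contained sketch of that pattern, not ONNX Runtime's code; Status, ToStatus and softmax_kernel are simplified stand-ins for common::Status, the CUDA_CALL macro and the real softmax kernels.

// Minimal sketch of the void -> Status launcher pattern (stand-in types, not ORT code).
#include <cuda_runtime.h>

struct Status {                               // stand-in for onnxruntime::common::Status
  cudaError_t err{cudaSuccess};
  static Status OK() { return Status{}; }
  bool IsOK() const { return err == cudaSuccess; }
};

inline Status ToStatus(cudaError_t e) {       // stand-in for what CUDA_CALL produces
  Status s;
  s.err = e;
  return s;
}

__global__ void softmax_kernel(float* dst, const float* src, int n) {
  // kernel body omitted; only the launch/error-reporting shape matters here
}

Status dispatch_softmax(cudaStream_t stream, float* dst, const float* src,
                        int softmax_elements, int batch_count) {
  if (softmax_elements == 0) {
    return Status::OK();                      // nothing launched, nothing to report
  }
  softmax_kernel<<<batch_count, 256, 0, stream>>>(dst, src, softmax_elements);
  return ToStatus(cudaGetLastError());        // surface launch failures to the caller
}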
@@ -337,11 +337,11 @@ Status ProcessLogits(const OrtValue& logits, //
 
   const CudaT* X_data = is_reuse_logits_buffer ? logits_data : reinterpret_cast<const CudaT*>(next_token_logits.data());
 
-  dispatch_blockwise_softmax_forward<CudaT, float, float, true>(
+  ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward<CudaT, float, float, true>(
       cuda_stream, Y_data, X_data, vocab_size,
       is_reuse_logits_buffer ? padded_vocab_size : vocab_size,
       vocab_size,
-      batch_size * num_beams);
+      batch_size * num_beams)));
 
 #ifdef DEBUG_GENERATION
   dumper->Print("next_token_scores after softmax", next_token_scores.data(), batch_size, num_beams, vocab_size);
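A small but easy-to-miss detail in this call site (and in the two calls in sampling_cuda_helper.h below): the dispatcher call is wrapped in an extra pair of parentheses inside ORT_RETURN_IF_ERROR. The template argument list contains commas, and the preprocessor would otherwise split the call across several macro arguments. The snippet below illustrates the issue with a generic error-propagating macro; it is not ORT's actual macro definition.

// Why the extra parentheses are needed: commas in a template argument list split
// a macro argument. Generic stand-ins; not ORT's actual macro definition.
struct Status {
  bool ok = true;
  bool IsOK() const { return ok; }
};

#define RETURN_IF_ERROR(expr)              \
  do {                                     \
    Status _status = (expr);               \
    if (!_status.IsOK()) return _status;   \
  } while (0)

template <typename A, typename B>
Status f(int) { return Status{}; }

Status caller() {
  // RETURN_IF_ERROR(f<int, float>(0));   // fails: the preprocessor sees two macro
  //                                      // arguments, "f<int" and "float>(0)"
  RETURN_IF_ERROR((f<int, float>(0)));    // inner parentheses keep it a single argument
  return Status{};
}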
30 changes: 15 additions & 15 deletions onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
@@ -88,14 +88,14 @@ Status Sample(AllocatorPtr& allocator,
 #endif
 
   gsl::span<float>& d_sorted_softmaxed_score = sampling_state->d_sorted_softmaxed_score;
-  dispatch_blockwise_softmax_forward<CudaT, float, float, false>(cuda_stream,
-      d_sorted_softmaxed_score.data(),
-      reinterpret_cast<CudaT*>(d_sorted_score.data()),
-      parameters->vocab_size,
-      parameters->vocab_size,
-      parameters->vocab_size,
-      parameters->batch_size);
-
+  ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward<CudaT, float, float, false>(cuda_stream,
+      d_sorted_softmaxed_score.data(),
+      reinterpret_cast<CudaT*>(d_sorted_score.data()),
+      parameters->vocab_size,
+      parameters->vocab_size,
+      parameters->vocab_size,
+      parameters->batch_size)));
 #ifdef DEBUG_GENERATION
   dumper->Print("d_sorted_softmaxed_score_buffer",
                 d_sorted_softmaxed_score.data(),
@@ -122,13 +122,13 @@ Status Sample(AllocatorPtr& allocator,
 #endif
 
   gsl::span<float>& d_softmaxed_score = sampling_state->d_softmaxed_score;
-  dispatch_blockwise_softmax_forward<CudaT, float, float, false>(cuda_stream,
-      d_softmaxed_score.data(),
-      reinterpret_cast<CudaT*>(next_token_scores.data()),
-      parameters->vocab_size,
-      parameters->vocab_size,
-      parameters->vocab_size,
-      parameters->batch_size);
+  ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward<CudaT, float, float, false>(cuda_stream,
+      d_softmaxed_score.data(),
+      reinterpret_cast<CudaT*>(next_token_scores.data()),
+      parameters->vocab_size,
+      parameters->vocab_size,
+      parameters->vocab_size,
+      parameters->batch_size)));
 
 #ifdef DEBUG_GENERATION
   dumper->Print("d_softmaxed_score_buffer",
12 changes: 6 additions & 6 deletions onnxruntime/contrib_ops/rocm/bert/attention_softmax.h
@@ -513,12 +513,12 @@ Status ComputeSoftmaxWithRawMask(hipStream_t stream,
   }
 
   if (use_persistent_softmax) {
-    dispatch_warpwise_softmax_forward<T, T, float, false>(stream,
-        output,
-        persistent_softmax_workspace,
-        all_sequence_length,
-        all_sequence_length,
-        batch_size * num_heads * sequence_length);
+    return dispatch_warpwise_softmax_forward<T, T, float, false>(stream,
+        output,
+        persistent_softmax_workspace,
+        all_sequence_length,
+        all_sequence_length,
+        batch_size * num_heads * sequence_length);
   }
 
   return HIP_CALL(hipPeekAtLastError());
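Worth noting while reading the ROCm half of the change: ComputeSoftmaxWithRawMask ends with HIP_CALL(hipPeekAtLastError()), while the softmax dispatchers further down return HIP_CALL(hipGetLastError()). The two HIP queries (like their CUDA counterparts) differ only in whether they clear the sticky error state, as the short sketch below illustrates; it is an aside written against plain HIP, not part of the diff.

// Aside (not part of the diff): hipPeekAtLastError() reports the last error without
// resetting it; hipGetLastError() reports it and resets the state to hipSuccess.
#include <hip/hip_runtime.h>
#include <cstdio>

void report_launch_errors() {
  hipError_t peeked = hipPeekAtLastError();  // error state is left untouched
  hipError_t taken = hipGetLastError();      // same error, but the state is cleared
  std::printf("peeked=%s taken=%s now=%s\n",
              hipGetErrorString(peeked),
              hipGetErrorString(taken),
              hipGetErrorString(hipPeekAtLastError()));  // hipSuccess after the reset
}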
11 changes: 4 additions & 7 deletions onnxruntime/core/providers/cuda/math/softmax.cc
@@ -26,15 +26,12 @@ Status SoftMaxComputeHelper(
   auto X_data = reinterpret_cast<const CudaT*>(X);
 
   if (D <= 1024 && D * sizeof(T) <= 4096) {
-    dispatch_warpwise_softmax_forward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>(
+    return dispatch_warpwise_softmax_forward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>(
         stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));
-  } else {
-    dispatch_blockwise_softmax_forward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>(
-        stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
-        gsl::narrow_cast<int>(N));
   }
-
-  return Status::OK();
+  return dispatch_blockwise_softmax_forward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>(
+      stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
+      gsl::narrow_cast<int>(N));
 }
 
 #define SPECIALIZED_SOFTMAX_HELPER_IMPL(T) \
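Beyond the mechanical return-type change, this hunk changes behaviour: previously a failed launch inside either dispatcher was invisible to SoftMaxComputeHelper, which unconditionally returned Status::OK(); now the dispatcher's own Status is what the caller sees, and the else branch becomes unnecessary because the if branch returns. A simplified sketch of the resulting control flow follows; DispatchWarpwise, DispatchBlockwise and the local Status type are hypothetical stand-ins, not the ORT symbols.

// Simplified sketch of the helper after the change (stand-in names, not ORT code).
#include <cstdint>
#include <cuda_runtime.h>

struct Status {
  cudaError_t err{cudaSuccess};
  static Status OK() { return Status{}; }
};

// Hypothetical dispatchers standing in for dispatch_warpwise/_blockwise_softmax_forward.
Status DispatchWarpwise(cudaStream_t, float*, const float*, int, int) { return Status::OK(); }
Status DispatchBlockwise(cudaStream_t, float*, const float*, int, int) { return Status::OK(); }

Status SoftMaxHelperSketch(cudaStream_t stream, float* y, const float* x,
                           int64_t N, int64_t D) {
  // Small rows take the warp-per-row kernel; anything larger takes the block-per-row path.
  if (D <= 1024 && D * sizeof(float) <= 4096) {
    return DispatchWarpwise(stream, y, x, static_cast<int>(D), static_cast<int>(N));
  }
  // No else needed: the branch above returns, and the dispatcher's Status
  // (rather than a blanket Status::OK()) is what propagates to the caller.
  return DispatchBlockwise(stream, y, x, static_cast<int>(D), static_cast<int>(N));
}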
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cuda/math/softmax.h
@@ -18,12 +18,12 @@ Status SoftMaxComputeHelper(
     int64_t axis);
 
 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-void dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const input_t* src,
-    int softmax_elements, int softmax_elements_stride, int batch_count);
+Status dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const input_t* src,
+    int softmax_elements, int softmax_elements_stride, int batch_count);
 
 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-void dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, const input_t* input,
-    int softmax_elements, int input_stride, int output_stride, int batch_count);
+Status dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, const input_t* input,
+    int softmax_elements, int input_stride, int output_stride, int batch_count);
 
 template <typename T>
 class Softmax final : public CudaKernel {
28 changes: 15 additions & 13 deletions onnxruntime/core/providers/cuda/math/softmax_impl.cu
@@ -29,9 +29,9 @@ namespace onnxruntime {
 namespace cuda {
 
 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-void dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
+Status dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
   if (softmax_elements == 0) {
-    return;
+    return Status::OK();
   } else {
     int log2_elements = log2_ceil(softmax_elements);
     const int next_power_of_two = 1 << log2_elements;
@@ -99,24 +99,25 @@ void dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const
         break;
     }
   }
+  return CUDA_CALL(cudaGetLastError());
 }
 
-#define SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \
-  template void dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(cudaStream_t stream, output_t * dst, \
-      const input_t* src, int softmax_elements, \
-      int softmax_elements_stride, int batch_count); \
-  template void dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(cudaStream_t stream, output_t * dst, \
-      const input_t* src, int softmax_elements, \
-      int softmax_elements_stride, int batch_count);
+#define SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \
+  template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(cudaStream_t stream, output_t * dst, \
+      const input_t* src, int softmax_elements, \
+      int softmax_elements_stride, int batch_count); \
+  template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(cudaStream_t stream, output_t * dst, \
+      const input_t* src, int softmax_elements, \
+      int softmax_elements_stride, int batch_count);
 
 SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(float, float, float)
 SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(half, half, float)
 SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(double, double, double)
 SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(BFloat16, BFloat16, float)
 
 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-void dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, const input_t* input, int softmax_elements,
-    int input_stride, int output_stride, int batch_count) {
+Status dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, const input_t* input, int softmax_elements,
+    int input_stride, int output_stride, int batch_count) {
   dim3 grid(batch_count);
   constexpr int ILP = sizeof(float4) / sizeof(input_t);
   dim3 block = SoftMax_getBlockSize(ILP, softmax_elements);
@@ -129,13 +130,14 @@ void dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, c
         <<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input),
                                                            softmax_elements, input_stride, output_stride);
   }
+  return CUDA_CALL(cudaGetLastError());
 }
 
 #define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \
-  template void dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>( \
+  template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>( \
      cudaStream_t stream, output_t * output, const input_t* src, int softmax_elements, \
      int input_stride, int output_stride, int batch_count); \
-  template void dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>( \
+  template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>( \
      cudaStream_t stream, output_t * output, const input_t* src, int softmax_elements, \
      int input_stride, int output_stride, int batch_count);
 
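The macro changes in this file are forced by the signature change: an explicit instantiation has to repeat the template's declared signature, so once the dispatchers return Status, the old "template void ..." instantiations in SPECIALIZED_WRAPWISE_SOFTMAX_IMPL and SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL no longer name any existing template and would fail to compile. A stripped-down illustration with generic names (not the ORT templates):

// Why the explicit-instantiation macros must change along with the return type.
struct Status {
  bool ok = true;
};

template <typename T, bool is_log>
Status dispatch(T* dst, const T* src, int n) {  // simplified stand-in for the dispatchers
  (void)dst; (void)src; (void)n;
  return Status{};
}

// OK: spells out the current signature, including the Status return type.
template Status dispatch<float, false>(float*, const float*, int);

// Error if uncommented: "void dispatch<float, true>(float*, const float*, int)"
// matches no declared template once the return type is Status.
// template void dispatch<float, true>(float*, const float*, int);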
11 changes: 4 additions & 7 deletions onnxruntime/core/providers/rocm/math/softmax.cc
@@ -26,15 +26,12 @@ Status SoftMaxComputeHelper(
   auto X_data = reinterpret_cast<const HipT*>(X);
 
   if (D <= 1024 && D * sizeof(T) <= 4096) {
-    dispatch_warpwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
+    return dispatch_warpwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
        stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));
-  } else {
-    dispatch_blockwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
-        stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
-        gsl::narrow_cast<int>(N));
   }
-
-  return Status::OK();
+  return dispatch_blockwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
+      stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
+      gsl::narrow_cast<int>(N));
 }
 
 #define SPECIALIZED_SOFTMAX_HELPER_IMPL(T) \
24 changes: 13 additions & 11 deletions onnxruntime/core/providers/rocm/math/softmax_impl.cu
@@ -30,9 +30,9 @@ namespace onnxruntime {
 namespace rocm {
 
 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-void dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
+Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
   if (softmax_elements == 0) {
-    return;
+    return Status::OK();
   } else {
     int log2_elements = log2_ceil(softmax_elements);
     const int next_power_of_two = 1 << log2_elements;
@@ -88,19 +88,20 @@ void dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const
         break;
     }
   }
+  return HIP_CALL(hipGetLastError());
 }
 
 #define SPECIALIZED_SOFTMAX_IMPL(input_t, output_t, acc_t) \
-  template void dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); \
-  template void dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count);
+  template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); \
+  template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count);
 
 SPECIALIZED_SOFTMAX_IMPL(float, float, float)
 SPECIALIZED_SOFTMAX_IMPL(half, half, float)
 SPECIALIZED_SOFTMAX_IMPL(double, double, double)
 SPECIALIZED_SOFTMAX_IMPL(BFloat16, BFloat16, float)
 
 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-void dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements,
+Status dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements,
                                         int input_stride, int output_stride, int batch_count) {
   dim3 grid(batch_count);
   constexpr int ILP = sizeof(float4) / sizeof(input_t);
@@ -114,14 +115,15 @@ void dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, co
         <<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input),
                                                            softmax_elements, input_stride, output_stride);
   }
+  return HIP_CALL(hipGetLastError());
 }
 
-#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \
-  template void dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>( \
-      hipStream_t stream, output_t * output, const input_t* src, int softmax_elements, \
-      int input_stride, int output_stride, int batch_count); \
-  template void dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>( \
-      hipStream_t stream, output_t * output, const input_t* src, int softmax_elements, \
+#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \
+  template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>( \
+      hipStream_t stream, output_t * output, const input_t* src, int softmax_elements, \
+      int input_stride, int output_stride, int batch_count); \
+  template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>( \
+      hipStream_t stream, output_t * output, const input_t* src, int softmax_elements, \
       int input_stride, int output_stride, int batch_count);
 
 SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float)
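One caveat the new return values inherit from the runtime APIs: hipGetLastError()/cudaGetLastError() immediately after a launch only reports problems detectable at launch time (invalid configuration, missing kernel image, and so on). Errors that occur while the kernel executes, such as an out-of-bounds access, surface on a later synchronizing call. The sketch below, written against plain HIP rather than ORT's wrappers, shows where each kind of failure would appear.

// Launch-time vs execution-time errors (plain HIP, not ORT code).
#include <hip/hip_runtime.h>

__global__ void noop_kernel() {}

hipError_t launch_and_check(hipStream_t stream) {
  noop_kernel<<<1, 1, 0, stream>>>();
  hipError_t launch_err = hipGetLastError();  // bad grid/block configuration shows up here
  if (launch_err != hipSuccess) {
    return launch_err;
  }
  return hipStreamSynchronize(stream);        // faults during kernel execution show up here
}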