Skip to content

Commit

Permalink
[xla:cpu] Optimize ThunkExecutor::Execute part #1
Browse files Browse the repository at this point in the history
name                                     old cpu/op   new cpu/op   delta
BM_SelectAndScatterF32/128/process_time   889µs ± 1%   740µs ± 3%  -16.70%
BM_SelectAndScatterF32/256/process_time  3.64ms ± 2%  3.00ms ± 1%  -17.64%
BM_SelectAndScatterF32/512/process_time  15.3ms ± 1%  13.1ms ± 3%  -14.61%

PiperOrigin-RevId: 657693426
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Jul 31, 2024
1 parent 7649c3b commit 720b450
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 128 deletions.
205 changes: 144 additions & 61 deletions third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@ limitations under the License.

#include "xla/service/cpu/runtime/kernel_thunk.h"

#include <cstddef>

#define EIGEN_USE_THREADS

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/base/optimization.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/numeric/bits.h"
#include "absl/status/status.h"
Expand All @@ -51,50 +51,109 @@ limitations under the License.
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
namespace internal {

absl::StatusOr<std::unique_ptr<KernelThunk>> KernelThunk::Create(
Info info, absl::Span<const BufferAllocation::Slice> arguments_buffers,
// Verifies that every resolved kernel argument pointer satisfies the
// requested minimum alignment. Codegen assumes aligned buffers; a misaligned
// buffer can segfault or silently produce wrong results, so we fail fast with
// an Internal error instead.
static absl::Status CheckBufferAlignment(
    const Thunk::Info& info, uint64_t min_alignment,
    absl::Span<const SE_HOST_KernelArg> kernel_args) {
  if (min_alignment == 0) return absl::OkStatus();

  // For a power-of-two alignment, an aligned address has all mask bits clear.
  const uintptr_t misalignment_mask = min_alignment - 1;
  for (int64_t arg_idx = 0; arg_idx < kernel_args.size(); ++arg_idx) {
    const auto address = reinterpret_cast<uintptr_t>(kernel_args[arg_idx].data);
    if (ABSL_PREDICT_FALSE((address & misalignment_mask) != 0)) {
      return Internal(
          "Host kernel %s buffer argument #%d (%p) is not aligned to a "
          "required minimum alignment of %d bytes",
          info.op_name, arg_idx, kernel_args[arg_idx].data, min_alignment);
    }
  }

  return absl::OkStatus();
}

// VLOGs kernel arguments resolved from the buffer allocations.
//
// `kernel_args` is laid out as all argument pointers followed by all result
// pointers, so result #i lives at index `arguments_buffers.size() + i`.
static void VlogKernelArgs(
    absl::Span<const BufferAllocation::Slice> arguments_buffers,
    absl::Span<const BufferAllocation::Slice> results_buffers,
    absl::Span<const SE_HOST_KernelArg> kernel_args) {
  for (int64_t i = 0; i < arguments_buffers.size(); ++i) {
    VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", i,
                                  arguments_buffers[i].ToString(),
                                  kernel_args[i].data);
  }
  for (int64_t i = 0; i < results_buffers.size(); ++i) {
    VLOG(3) << absl::StreamFormat(
        " res #%d: %s (%p)", i, results_buffers[i].ToString(),
        kernel_args[arguments_buffers.size() + i].data);
  }
}

return absl::WrapUnique(
new KernelThunk(std::move(info), arguments_buffers, results_buffers,
std::move(kernel_name), thread_dim, min_alignment));
// Builds the thunk's buffer-use list for the given buffers: every argument
// slice is recorded as a read and every result slice as a write.
static Thunk::BufferUses KernelBufferUses(
    absl::Span<const BufferAllocation::Slice> arguments_buffers,
    absl::Span<const BufferAllocation::Slice> results_buffers) {
  Thunk::BufferUses uses;
  for (const BufferAllocation::Slice& argument : arguments_buffers) {
    uses.emplace_back(argument, BufferUse::kRead);
  }
  for (const BufferAllocation::Slice& result : results_buffers) {
    uses.emplace_back(result, BufferUse::kWrite);
  }
  return uses;
}

// Constructs a kernel thunk. `num_arguments`/`num_results` select either
// fixed-size (std::array) or dynamic (vector) storage for the buffer slices
// and kernel arguments; fixed-size storage avoids heap allocation on the hot
// execution path.
template <int64_t num_arguments, int64_t num_results>
KernelThunk<num_arguments, num_results>::KernelThunk(
    Info info, absl::Span<const BufferAllocation::Slice> arguments_buffers,
    absl::Span<const BufferAllocation::Slice> results_buffers,
    std::string kernel_name, se::ThreadDim thread_dim,
    std::optional<uint64_t> min_alignment)
    : Thunk(Kind::kKernel, std::move(info)),
      num_kernel_args_(arguments_buffers.size() + results_buffers.size()),
      kernel_name_(std::move(kernel_name)),
      thread_dim_(thread_dim),
      min_alignment_(min_alignment),
      // A kernel with a single thread can be invoked directly on the caller
      // thread (no task scheduling) — see call_once_ usage at execution time.
      call_once_(thread_dim_ == se::ThreadDim()),
      kernel_ptr_(nullptr) {
  // Resize storage for arguments and results buffers if it is dynamic.
  if constexpr (IsDynamic(num_arguments)) {
    arguments_buffers_.resize(arguments_buffers.size());
  }
  if constexpr (IsDynamic(num_results)) {
    results_buffers_.resize(results_buffers.size());
  }

  // Copy buffers from the arguments and results.
  for (size_t i = 0; i < arguments_buffers.size(); ++i) {
    arguments_buffers_[i] = arguments_buffers[i];
  }
  for (size_t i = 0; i < results_buffers.size(); ++i) {
    results_buffers_[i] = results_buffers[i];
  }

  // Resize storage for kernel arguments if it is dynamic.
  if constexpr (IsDynamic(num_arguments) || IsDynamic(num_results)) {
    kernel_args_.resize(num_kernel_args_);
  }

  // Initialize kernel arguments with null pointers and known buffer sizes.
  // We'll use them as a template to resolve buffer addresses at run time.
  for (size_t i = 0; i < arguments_buffers.size(); ++i) {
    kernel_args_[i] = SE_HOST_KernelArg{
        nullptr, static_cast<size_t>(arguments_buffers_[i].size())};
  }
  for (size_t i = 0; i < results_buffers.size(); ++i) {
    kernel_args_[arguments_buffers_.size() + i] = SE_HOST_KernelArg{
        nullptr, static_cast<size_t>(results_buffers_[i].size())};
  }
}

tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
template <int64_t num_arguments, int64_t num_results>
ABSL_ATTRIBUTE_ALWAYS_INLINE tsl::AsyncValueRef<Thunk::ExecuteEvent>
KernelThunk<num_arguments, num_results>::ExecuteInternal(
const ExecuteParams& params) {
tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); });

Expand All @@ -104,7 +163,7 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
kernel_name_, arguments_buffers_.size(), results_buffers_.size(),
thread_dim_.ToString());

absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args = kernel_args_;
KernelArgs kernel_args = kernel_args_;
SE_HOST_KernelArg* kernel_args_ptr = kernel_args.data();

const BufferAllocations* allocations = params.buffer_allocations;
Expand All @@ -130,12 +189,13 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
}

if (ABSL_PREDICT_FALSE(VLOG_IS_ON(3))) {
VlogKernelArgs(kernel_args);
VlogKernelArgs(arguments_buffers_, results_buffers_, kernel_args);
}

  // Check that all resolved buffers are properly aligned.
if constexpr (ShouldCheckBufferSlices()) {
TF_RETURN_IF_ERROR(CheckBufferAlignment(kernel_args));
TF_RETURN_IF_ERROR(
CheckBufferAlignment(info(), min_alignment_.value_or(0), kernel_args));
}

// TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk
Expand Down Expand Up @@ -173,45 +233,68 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
return OkExecuteEvent();
}

absl::Status KernelThunk::CheckBufferAlignment(
absl::Span<const SE_HOST_KernelArg> kernel_args) {
if (min_alignment_.has_value()) {
for (int64_t i = 0; i < num_kernel_args_; ++i) {
auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data);
if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) {
return Internal(
"Host kernel %s buffer argument #%d (%p) is not aligned to a "
"required minimum alignment of %d bytes",
info().op_name, i, kernel_args[i].data, *min_alignment_);
}
}
}
return absl::OkStatus();
// Reports this thunk's buffer accesses by delegating to the shared
// KernelBufferUses helper (arguments are reads, results are writes).
template <int64_t num_arguments, int64_t num_results>
KernelThunk<num_arguments, num_results>::BufferUses
KernelThunk<num_arguments, num_results>::buffer_uses() const {
  return KernelBufferUses(arguments_buffers_, results_buffers_);
}

void KernelThunk::VlogKernelArgs(
absl::Span<const SE_HOST_KernelArg> kernel_args) {
for (int64_t i = 0; i < arguments_buffers_.size(); ++i) {
VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", i,
arguments_buffers_[i].ToString(),
kernel_args[i].data);
}
for (int64_t i = 0; i < results_buffers_.size(); ++i) {
VLOG(3) << absl::StreamFormat(
" res #%d: %s (%p)", i, results_buffers_[i].ToString(),
kernel_args[arguments_buffers_.size() + i].data);
}
} // namespace internal

// Forwards execution to the templated base implementation
// (Base::ExecuteInternal), which holds the shared kernel-launch logic.
tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
    const Thunk::ExecuteParams& params) {
  return Base::ExecuteInternal(params);
}

KernelThunk::BufferUses KernelThunk::buffer_uses() const {
BufferUses buffer_uses;
for (const BufferAllocation::Slice& buffer : arguments_buffers_) {
buffer_uses.emplace_back(buffer, BufferUse::kRead);
}
for (const BufferAllocation::Slice& buffer : results_buffers_) {
buffer_uses.emplace_back(buffer, BufferUse::kWrite);
// Forwards execution to the templated base implementation shared by all
// kernel-thunk specializations.
template <int64_t num_arguments, int64_t num_results>
tsl::AsyncValueRef<Thunk::ExecuteEvent>
SmallKernelThunk<num_arguments, num_results>::Execute(
    const Thunk::ExecuteParams& params) {
  return Base::ExecuteInternal(params);
}

// Creates a kernel thunk for the given argument/result buffers.
//
// Returns a SmallKernelThunk specialization (fixed-size inline storage) for
// the most common argument/result arities, and falls back to the generic
// dynamically-sized KernelThunk otherwise.
//
// Fails with an Internal error if `min_alignment` is not a power of two,
// since the runtime alignment check relies on a power-of-two bit mask.
absl::StatusOr<std::unique_ptr<Thunk>> KernelThunk::Create(
    Thunk::Info info,
    absl::Span<const BufferAllocation::Slice> arguments_buffers,
    absl::Span<const BufferAllocation::Slice> results_buffers,
    std::string kernel_name, se::ThreadDim thread_dim,
    std::optional<uint64_t> min_alignment) {
  if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) {
    return Internal("Host kernel %s minimum alignment %d is not a power of 2",
                    info.op_name, *min_alignment);
  }

  // Constructs a SmallKernelThunk with compile-time argument/result counts
  // taken from the passed integral constants.
  auto make_small_kernel_thunk = [&](auto num_arguments, auto num_results) {
    return absl::WrapUnique(
        new SmallKernelThunk<num_arguments(), num_results()>(
            std::move(info), arguments_buffers, results_buffers,
            std::move(kernel_name), thread_dim, min_alignment));
  };

  constexpr auto _0 = std::integral_constant<size_t, 0>{};
  constexpr auto _1 = std::integral_constant<size_t, 1>{};
  constexpr auto _2 = std::integral_constant<size_t, 2>{};
  constexpr auto _3 = std::integral_constant<size_t, 3>{};
  constexpr auto _4 = std::integral_constant<size_t, 4>{};
  constexpr auto _5 = std::integral_constant<size_t, 5>{};
  constexpr auto _6 = std::integral_constant<size_t, 6>{};

  const size_t num_arguments = arguments_buffers.size();
  const size_t num_results = results_buffers.size();

  // Return SmallKernelThunk specializations for the most common cases.
  if (num_results == 1) {
    switch (num_arguments) {
      case 0: return make_small_kernel_thunk(_0, _1);
      case 1: return make_small_kernel_thunk(_1, _1);
      case 2: return make_small_kernel_thunk(_2, _1);
      case 3: return make_small_kernel_thunk(_3, _1);
      case 4: return make_small_kernel_thunk(_4, _1);
      case 5: return make_small_kernel_thunk(_5, _1);
      case 6: return make_small_kernel_thunk(_6, _1);
    }
  }

  // Return a generic KernelThunk for dynamic numbers of arguments and results.
  return absl::WrapUnique(
      new KernelThunk(std::move(info), arguments_buffers, results_buffers,
                      std::move(kernel_name), thread_dim, min_alignment));
}

} // namespace xla::cpu
Loading

0 comments on commit 720b450

Please sign in to comment.