diff --git a/xla/service/gpu/runtime/BUILD b/xla/service/gpu/runtime/BUILD index f21a433dd23e4..c989a2c1a6608 100644 --- a/xla/service/gpu/runtime/BUILD +++ b/xla/service/gpu/runtime/BUILD @@ -93,6 +93,7 @@ cc_library( "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:trace_command_buffer_factory", "//xla/stream_executor/gpu:gpu_blas_lt", + "//xla/runtime:buffer_use", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -138,6 +139,7 @@ cc_library( ":wait_for_streams_thunk", ":while_thunk", "//xla:util", + "//xla/runtime:buffer_use", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -160,6 +162,7 @@ xla_test( "//xla/service:platform_util", "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:launch_dimensions", + "//xla/runtime:buffer_use", "//xla/stream_executor:command_buffer", "//xla/stream_executor:device_memory", "//xla/stream_executor:platform", @@ -354,6 +357,7 @@ xla_test( "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", + "//xla/runtime:buffer_use", "//xla/stream_executor:blas", "//xla/stream_executor:command_buffer", "//xla/stream_executor:device_description", diff --git a/xla/service/gpu/runtime/command_buffer_cmd.cc b/xla/service/gpu/runtime/command_buffer_cmd.cc index ba0f168e40efc..af5b42f5066a3 100644 --- a/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -89,7 +89,7 @@ limitations under the License. namespace xla::gpu { using ExecutionScopeId = se::CommandBuffer::ExecutionScopeId; -using MemoryAccess = CommandBufferCmd::MemoryAccess; +using MemoryAccess = BufferUse::MemoryAccess; std::string CommandBufferCmdString(CommandBufferCmdType type) { switch (type) { @@ -195,13 +195,13 @@ CommandBufferCmdSequence::CommandBufferCmdSequence( : synchronization_mode_(synchronization_mode) {} void CommandBufferCmdSequence::Append(std::unique_ptr cmd) { - for (const CommandBufferCmd::BufferUsage& buffer : cmd->buffers()) { + for (const BufferUse& buffer : cmd->buffers()) { buffers_.insert(buffer); - allocs_indices_.insert(buffer.slice.index()); + allocs_indices_.insert(buffer.slice().index()); } ExecutionStreamId execution_stream_id = cmd->execution_stream_id(); - CommandBufferCmd::BufferUsageVector buffers = cmd->buffers(); + CommandBufferCmd::BufferUseVector buffers = cmd->buffers(); bool requires_barrier = HasConflicts(execution_stream_id, buffers); // Always add barriers between commands if we want to serialize execution. @@ -254,24 +254,26 @@ bool Overlaps(const BufferAllocation::Slice& slice, bool CommandBufferCmdSequence::HasConflicts( ExecutionStreamId execution_stream_id, - const CommandBufferCmd::BufferUsageVector& buffers) { + const CommandBufferCmd::BufferUseVector& buffers) { auto& rwset = read_write_sets_[execution_stream_id]; return absl::c_any_of(buffers, [&](const auto& buffer) { - return buffer.access == MemoryAccess::kWrite - ? Overlaps(buffer.slice, rwset.write) || - Overlaps(buffer.slice, rwset.read) - : Overlaps(buffer.slice, rwset.write); + return buffer.access() == MemoryAccess::kWrite + ? Overlaps(buffer.slice(), rwset.write) || + Overlaps(buffer.slice(), rwset.read) + : Overlaps(buffer.slice(), rwset.write); }); } void CommandBufferCmdSequence::TrackBuffers( ExecutionStreamId execution_stream_id, - const CommandBufferCmd::BufferUsageVector& buffers) { + const CommandBufferCmd::BufferUseVector& buffers) { auto& rwset = read_write_sets_[execution_stream_id]; - for (const CommandBufferCmd::BufferUsage& buffer : buffers) { - if (buffer.access == MemoryAccess::kWrite) rwset.write.insert(buffer.slice); - if (buffer.access == MemoryAccess::kRead) rwset.read.insert(buffer.slice); + for (const BufferUse& buffer : buffers) { + if (buffer.access() == MemoryAccess::kWrite) + rwset.write.insert(buffer.slice()); + if (buffer.access() == MemoryAccess::kRead) + rwset.read.insert(buffer.slice()); } } @@ -346,8 +348,8 @@ absl::Status CommandBufferCmdSequence::Record( return absl::OkStatus(); } -const absl::flat_hash_set& -CommandBufferCmdSequence::buffers() const { +const absl::flat_hash_set& CommandBufferCmdSequence::buffers() + const { return buffers_; } @@ -369,13 +371,13 @@ std::vector CommandBufferCmdSequence::barriers() const { TracedCommandBuffer::TracedCommandBuffer( const CommandBufferCmd* trace_cmd, - CommandBufferCmd::BufferUsageVector buffers, int64_t capacity) + CommandBufferCmd::BufferUseVector buffers, int64_t capacity) : trace_cmd_(trace_cmd), capacity_(capacity), entries_(capacity) { CHECK_GT(capacity, 0) << "capacity must be larger than 0"; // NOLINT // Collect unique buffer allocation indices in a set first and convert to // vector as flat hash set iteration has measurable overheads. absl::flat_hash_set allocs_indices; - for (auto& buffer : buffers) allocs_indices.insert(buffer.slice.index()); + for (auto& buffer : buffers) allocs_indices.insert(buffer.slice().index()); allocs_indices_.assign(allocs_indices.begin(), allocs_indices.end()); } @@ -535,7 +537,7 @@ ComputationIdCmd::ComputationIdCmd(ExecutionStreamId execution_stream_id, dest_(dest), kind_(kind) {} -CommandBufferCmd::BufferUsageVector ComputationIdCmd::buffers() { +CommandBufferCmd::BufferUseVector ComputationIdCmd::buffers() { return {{dest_, MemoryAccess::kWrite}}; } @@ -674,8 +676,8 @@ absl::Status LaunchCmd::Record(const Thunk::ExecuteParams& execute_params, dims_.block_counts(), *kernel, *kernel_args); } -CommandBufferCmd::BufferUsageVector LaunchCmd::buffers() { - BufferUsageVector buffers; +CommandBufferCmd::BufferUseVector LaunchCmd::buffers() { + BufferUseVector buffers; for (int32_t i = 0; i < args_.size(); ++i) { buffers.emplace_back(args_[i], args_access_[i]); } @@ -746,8 +748,8 @@ absl::Status CustomKernelLaunchCmd::Record( custom_kernel_.block_dims(), *kernel, kernel_args); } -CommandBufferCmd::BufferUsageVector CustomKernelLaunchCmd::buffers() { - BufferUsageVector buffers; +CommandBufferCmd::BufferUseVector CustomKernelLaunchCmd::buffers() { + BufferUseVector buffers; for (int32_t i = 0; i < args_.size(); ++i) { buffers.emplace_back(args_[i], args_access_[i]); } @@ -790,7 +792,7 @@ absl::Status MemcpyDeviceToDeviceCmd::Record( num_bytes_); } -CommandBufferCmd::BufferUsageVector MemcpyDeviceToDeviceCmd::buffers() { +CommandBufferCmd::BufferUseVector MemcpyDeviceToDeviceCmd::buffers() { return {{dst_, MemoryAccess::kWrite}, {src_, MemoryAccess::kRead}}; } @@ -822,7 +824,7 @@ absl::Status MemzeroCmd::Record(const Thunk::ExecuteParams& execute_params, /*num_elements=*/dst_.size()); } -CommandBufferCmd::BufferUsageVector MemzeroCmd::buffers() { +CommandBufferCmd::BufferUseVector MemzeroCmd::buffers() { return {{dst_, MemoryAccess::kWrite}}; } @@ -857,7 +859,7 @@ absl::Status Memset32Cmd::Record(const Thunk::ExecuteParams& execute_params, /*num_elements=*/dst_.size() / sizeof(uint32_t)); } -CommandBufferCmd::BufferUsageVector Memset32Cmd::buffers() { +CommandBufferCmd::BufferUseVector Memset32Cmd::buffers() { return {{dst_, MemoryAccess::kWrite}}; } @@ -894,8 +896,8 @@ absl::Status IfCmd::Record(const Thunk::ExecuteParams& execute_params, bool IfCmd::force_update() { return then_commands_.force_update(); } -CommandBufferCmd::BufferUsageVector IfCmd::buffers() { - absl::flat_hash_set buffers; +CommandBufferCmd::BufferUseVector IfCmd::buffers() { + absl::flat_hash_set buffers; buffers.emplace(pred_, MemoryAccess::kRead); buffers.insert(then_commands_.buffers().begin(), then_commands_.buffers().end()); @@ -942,8 +944,8 @@ bool IfElseCmd::force_update() { return (then_commands_.force_update() || else_commands_.force_update()); } -CommandBufferCmd::BufferUsageVector IfElseCmd::buffers() { - absl::flat_hash_set buffers; +CommandBufferCmd::BufferUseVector IfElseCmd::buffers() { + absl::flat_hash_set buffers; buffers.emplace(pred_, MemoryAccess::kRead); buffers.insert(then_commands_.buffers().begin(), then_commands_.buffers().end()); @@ -992,8 +994,8 @@ bool CaseCmd::force_update() { [](const auto& seq) { return seq.force_update(); }); } -CommandBufferCmd::BufferUsageVector CaseCmd::buffers() { - absl::flat_hash_set buffers; +CommandBufferCmd::BufferUseVector CaseCmd::buffers() { + absl::flat_hash_set buffers; buffers.emplace(index_, MemoryAccess::kRead); for (auto& branch : branches_commands_) { buffers.insert(branch.buffers().begin(), branch.buffers().end()); @@ -1039,8 +1041,8 @@ absl::Status ForCmd::Record(const Thunk::ExecuteParams& execute_params, bool ForCmd::force_update() { return body_commands_.force_update(); } -CommandBufferCmd::BufferUsageVector ForCmd::buffers() { - absl::flat_hash_set buffers; +CommandBufferCmd::BufferUseVector ForCmd::buffers() { + absl::flat_hash_set buffers; buffers.emplace(loop_counter_, MemoryAccess::kWrite); buffers.insert(body_commands_.buffers().begin(), body_commands_.buffers().end()); @@ -1089,8 +1091,8 @@ bool WhileCmd::force_update() { return (cond_commands_.force_update() || body_commands_.force_update()); } -CommandBufferCmd::BufferUsageVector WhileCmd::buffers() { - absl::flat_hash_set buffers; +CommandBufferCmd::BufferUseVector WhileCmd::buffers() { + absl::flat_hash_set buffers; buffers.emplace(pred_, MemoryAccess::kWrite); buffers.insert(cond_commands_.buffers().begin(), cond_commands_.buffers().end()); @@ -1152,7 +1154,7 @@ absl::Status GemmCmd::Record(const Thunk::ExecuteParams& execute_params, }); } -CommandBufferCmd::BufferUsageVector GemmCmd::buffers() { +CommandBufferCmd::BufferUseVector GemmCmd::buffers() { return {{lhs_buffer_, MemoryAccess::kRead}, {rhs_buffer_, MemoryAccess::kRead}, {output_buffer_, MemoryAccess::kWrite}, @@ -1292,8 +1294,8 @@ absl::Status CublasLtCmd::Record(const Thunk::ExecuteParams& execute_params, }); } -CommandBufferCmd::BufferUsageVector CublasLtCmd::buffers() { - BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector CublasLtCmd::buffers() { + BufferUseVector buffer_usage; buffer_usage.reserve(13); buffer_usage.push_back({a_buffer_, MemoryAccess::kRead}); buffer_usage.push_back({b_buffer_, MemoryAccess::kRead}); @@ -1366,8 +1368,8 @@ absl::Status CuDnnCmd::Record(const Thunk::ExecuteParams& execute_params, }); } -CommandBufferCmd::BufferUsageVector CuDnnCmd::buffers() { - CommandBufferCmd::BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector CuDnnCmd::buffers() { + CommandBufferCmd::BufferUseVector buffer_usage; buffer_usage.reserve(args_.size()); for (int i = 0; i < args_.size() - 1; ++i) { buffer_usage.push_back({args_[i], MemoryAccess::kRead}); @@ -1524,8 +1526,8 @@ absl::Status CustomCallCmd::RecordXlaFfiCall( *nested_cmd); } -CommandBufferCmd::BufferUsageVector CustomCallCmd::buffers() { - CommandBufferCmd::BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector CustomCallCmd::buffers() { + CommandBufferCmd::BufferUseVector buffer_usage; for (auto& slices : {operands_, results_}) { for (const std::optional& slice : slices) { if (!slice.has_value()) continue; @@ -1558,7 +1560,7 @@ absl::Status BarrierCmd::Record(const Thunk::ExecuteParams& execute_params, return absl::OkStatus(); } -BarrierCmd::BufferUsageVector BarrierCmd::buffers() { return {}; } +BarrierCmd::BufferUseVector BarrierCmd::buffers() { return {}; } //===----------------------------------------------------------------------===// // CollectiveCmd @@ -1676,8 +1678,8 @@ absl::Status AllReduceCmd::Record(const Thunk::ExecuteParams& execute_params, }); } -CommandBufferCmd::BufferUsageVector AllReduceCmd::buffers() { - BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector AllReduceCmd::buffers() { + BufferUseVector buffer_usage; for (auto& buffer : buffers_) { buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); @@ -1743,8 +1745,8 @@ absl::Status ReduceScatterCmd::Record( }); } -CommandBufferCmd::BufferUsageVector ReduceScatterCmd::buffers() { - BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector ReduceScatterCmd::buffers() { + BufferUseVector buffer_usage; for (auto& buffer : buffers_) { buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); @@ -1807,8 +1809,8 @@ absl::Status AllToAllCmd::Record(const Thunk::ExecuteParams& execute_params, }); } -CommandBufferCmd::BufferUsageVector AllToAllCmd::buffers() { - BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector AllToAllCmd::buffers() { + BufferUseVector buffer_usage; for (auto& buffer : buffers_) { buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); @@ -1870,8 +1872,8 @@ absl::Status AllGatherCmd::Record(const Thunk::ExecuteParams& execute_params, }); } -CommandBufferCmd::BufferUsageVector AllGatherCmd::buffers() { - BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector AllGatherCmd::buffers() { + BufferUseVector buffer_usage; for (auto& buffer : buffers_) { buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); @@ -1935,8 +1937,8 @@ absl::Status CollectiveBroadcastCmd::Record( }); } -CommandBufferCmd::BufferUsageVector CollectiveBroadcastCmd::buffers() { - BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector CollectiveBroadcastCmd::buffers() { + BufferUseVector buffer_usage; for (auto& buffer : buffers_) { buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); @@ -2176,14 +2178,15 @@ absl::Status DynamicSliceFusionCmd::Record( *nested_command_buffer); } -CommandBufferCmd::BufferUsageVector DynamicSliceFusionCmd::buffers() { - CommandBufferCmd::BufferUsageVector buffers; +CommandBufferCmd::BufferUseVector DynamicSliceFusionCmd::buffers() { + CommandBufferCmd::BufferUseVector buffers; auto embed_buffers = embedded_commands_->buffers(); for (auto buffer_usage : embed_buffers) { - CHECK(embeded_to_origin_slice_map_[buffer_usage.slice.index()].has_value()); + CHECK( + embeded_to_origin_slice_map_[buffer_usage.slice().index()].has_value()); buffers.emplace_back( - embeded_to_origin_slice_map_[buffer_usage.slice.index()].value(), - buffer_usage.access); + embeded_to_origin_slice_map_[buffer_usage.slice().index()].value(), + buffer_usage.access()); } return buffers; } diff --git a/xla/service/gpu/runtime/command_buffer_cmd.h b/xla/service/gpu/runtime/command_buffer_cmd.h index eb08838644a6e..a1db49630094d 100644 --- a/xla/service/gpu/runtime/command_buffer_cmd.h +++ b/xla/service/gpu/runtime/command_buffer_cmd.h @@ -49,6 +49,7 @@ limitations under the License. #include "xla/service/gpu/runtime/dynamic_slice_thunk.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/thunk.h" +#include "xla/runtime/buffer_use.h" #include "xla/shape.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" @@ -118,28 +119,7 @@ class CommandBufferCmd { : cmd_type_(cmd_type), execution_stream_id_(execution_stream_id) {} virtual ~CommandBufferCmd() = default; - enum class MemoryAccess { kRead, kWrite }; - - // BufferUsage tracks memory access type for a buffer slice, so that we can - // correctly insert command buffer barriers to avoid read/write conflicts. - struct BufferUsage { - BufferUsage(BufferAllocation::Slice slice, MemoryAccess access) - : slice(slice), access(access) {} - - template - friend H AbslHashValue(H h, const BufferUsage& buffer) { - return H::combine(std::move(h), buffer.slice, buffer.access); - } - - bool operator==(const BufferUsage& other) const { - return slice == other.slice && access == other.access; - } - - BufferAllocation::Slice slice; - MemoryAccess access; - }; - - using BufferUsageVector = absl::InlinedVector; + using BufferUseVector = absl::InlinedVector; // A base class for externally managed command state. // @@ -244,7 +224,7 @@ class CommandBufferCmd { // Returns all buffers used by the cmd. These will be used to track cmd // updates, thus they need to be consistent across calls to the function. - virtual BufferUsageVector buffers() = 0; + virtual BufferUseVector buffers() = 0; // Returns true if command implemented as a nested command buffer. virtual bool IsNestedCommandBuffer() const { return false; } @@ -355,7 +335,7 @@ class CommandBufferCmdSequence { RecordMode mode = RecordMode::kExclusive); // Returns buffers referenced by commands in this sequence. - const absl::flat_hash_set& buffers() const; + const absl::flat_hash_set& buffers() const; // Returns buffer allocations indices referenced by commands in this sequence. const absl::flat_hash_set& allocs_indices() const; @@ -382,16 +362,16 @@ class CommandBufferCmdSequence { // Functions for tracking buffer usage of recorded commands and figuring out // when the next command requires a barrier for correctness. bool HasConflicts(ExecutionStreamId execution_stream_id, - const CommandBufferCmd::BufferUsageVector& buffers); + const CommandBufferCmd::BufferUseVector& buffers); void TrackBuffers(ExecutionStreamId execution_stream_id, - const CommandBufferCmd::BufferUsageVector& buffers); + const CommandBufferCmd::BufferUseVector& buffers); void ClearTrackedBuffers(ExecutionStreamId execution_stream_id); SynchronizationMode synchronization_mode_; std::vector commands_; // Buffers referenced by commands in this sequence. - absl::flat_hash_set buffers_; + absl::flat_hash_set buffers_; // Buffer allocations indices referenced by commands in this sequence. absl::flat_hash_set allocs_indices_; @@ -418,7 +398,7 @@ class CommandBufferCmdSequence { class TracedCommandBuffer : public CommandBufferCmd::State { public: explicit TracedCommandBuffer(const CommandBufferCmd* trace_cmd, - CommandBufferCmd::BufferUsageVector buffers, + CommandBufferCmd::BufferUseVector buffers, int64_t capacity = 16); // Returns cached command buffer traced using the same buffer addresses or @@ -476,7 +456,7 @@ class ComputationIdCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice dest_; @@ -503,8 +483,8 @@ class LaunchCmd : public CommandBufferCmd { public: LaunchCmd(ExecutionStreamId execution_stream_id, std::string kernel_name, absl::Span args, - absl::Span args_access, LaunchDimensions dims, - int64_t shmem_bytes); + absl::Span args_access, + LaunchDimensions dims, int64_t shmem_bytes); absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; @@ -513,12 +493,12 @@ class LaunchCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: std::string kernel_name_; std::vector args_; - std::vector args_access_; + std::vector args_access_; LaunchDimensions dims_; int64_t shmem_bytes_; @@ -537,7 +517,7 @@ class CustomKernelLaunchCmd : public CommandBufferCmd { public: CustomKernelLaunchCmd(ExecutionStreamId execution_stream_id, absl::Span args, - absl::Span args_access, + absl::Span args_access, CustomKernel custom_kernel); absl::Status Initialize(const Thunk::InitializeParams& params, @@ -547,11 +527,11 @@ class CustomKernelLaunchCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: std::vector args_; - std::vector args_access_; + std::vector args_access_; CustomKernel custom_kernel_; // Command sequence can be recorded concurrently for multiple command buffers @@ -575,7 +555,7 @@ class MemcpyDeviceToDeviceCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice dst_; @@ -596,7 +576,7 @@ class MemzeroCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice dst_; @@ -615,7 +595,7 @@ class Memset32Cmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice dst_; @@ -640,7 +620,7 @@ class IfCmd : public CommandBufferCmd { bool force_update() override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice pred_; @@ -666,7 +646,7 @@ class IfElseCmd : public CommandBufferCmd { bool force_update() override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice pred_; @@ -692,7 +672,7 @@ class CaseCmd : public CommandBufferCmd { bool force_update() override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice index_; @@ -718,7 +698,7 @@ class ForCmd : public CommandBufferCmd { bool force_update() override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: int32_t num_iterations_; @@ -745,7 +725,7 @@ class WhileCmd : public CommandBufferCmd { bool force_update() override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice pred_; @@ -772,7 +752,7 @@ class GemmCmd : public TracedCommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; bool IsNestedCommandBuffer() const final { return true; } @@ -814,7 +794,7 @@ class CublasLtCmd : public TracedCommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; bool IsNestedCommandBuffer() const final { return true; } @@ -867,7 +847,7 @@ class CuDnnCmd : public TracedCommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; bool IsNestedCommandBuffer() const final { return true; } @@ -920,7 +900,7 @@ class CustomCallCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; bool IsNestedCommandBuffer() const final { return true; } private: @@ -969,7 +949,7 @@ class BarrierCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: const ExecutionStreamId from_stream_id_; @@ -1039,7 +1019,7 @@ class AllReduceCmd : public CollectiveCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; AsyncStreamKind GetAsyncStreamKind() override { return AsyncStreamKind::kCollective; @@ -1065,7 +1045,7 @@ class ReduceScatterCmd : public CollectiveCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; AsyncStreamKind GetAsyncStreamKind() override { return AsyncStreamKind::kCollective; @@ -1091,7 +1071,7 @@ class AllToAllCmd : public CollectiveCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; AsyncStreamKind GetAsyncStreamKind() override { return AsyncStreamKind::kCollective; @@ -1117,7 +1097,7 @@ class AllGatherCmd : public CollectiveCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; AsyncStreamKind GetAsyncStreamKind() override { return AsyncStreamKind::kCollective; @@ -1142,7 +1122,7 @@ class CollectiveBroadcastCmd : public CollectiveCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: std::vector buffers_; @@ -1175,7 +1155,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; bool force_update() override; diff --git a/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc b/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc index de9734682ff87..442eb5d1f21a6 100644 --- a/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/runtime/wait_for_streams_thunk.h" #include "xla/service/gpu/runtime/while_thunk.h" +#include "xla/runtime/buffer_use.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -62,13 +63,14 @@ static absl::Status AppendCommands( //===----------------------------------------------------------------------===// using Command = std::unique_ptr; +using xla::BufferUse; static auto ArgsAccess(const std::vector& written) { - absl::InlinedVector args_access; + absl::InlinedVector args_access; args_access.reserve(written.size()); for (bool w : written) { - args_access.push_back(w ? CommandBufferCmd::MemoryAccess::kWrite - : CommandBufferCmd::MemoryAccess::kRead); + args_access.push_back(w ? BufferUse::MemoryAccess::kWrite + : BufferUse::MemoryAccess::kRead); } return args_access; } diff --git a/xla/service/gpu/runtime/command_buffer_cmd_test.cc b/xla/service/gpu/runtime/command_buffer_cmd_test.cc index 90b6e0666c8ad..45f24e3df09e1 100644 --- a/xla/service/gpu/runtime/command_buffer_cmd_test.cc +++ b/xla/service/gpu/runtime/command_buffer_cmd_test.cc @@ -45,9 +45,9 @@ limitations under the License. namespace xla::gpu { -using BufferUsage = CommandBufferCmd::BufferUsage; -using BufferUsageVector = CommandBufferCmd::BufferUsageVector; -using MemoryAccess = CommandBufferCmd::MemoryAccess; +using xla::BufferUse; +using BufferUseVector = CommandBufferCmd::BufferUseVector; +using MemoryAccess = BufferUse::MemoryAccess; static se::StreamExecutor* GpuExecutor() { auto name = @@ -65,7 +65,7 @@ static constexpr auto s1 = ExecutionStreamId(1); // buffer usage vector to the command buffer cmd sequence. struct TestOnlyCommandBufferCmd : public CommandBufferCmd { TestOnlyCommandBufferCmd(ExecutionStreamId execution_stream_id, - BufferUsageVector buffer_usage) + BufferUseVector buffer_usage) : CommandBufferCmd(CommandBufferCmdType::kUnknownCmd, execution_stream_id), buffer_usage(buffer_usage) {} @@ -75,9 +75,9 @@ struct TestOnlyCommandBufferCmd : public CommandBufferCmd { return absl::OkStatus(); } - BufferUsageVector buffers() override { return buffer_usage; } + BufferUseVector buffers() override { return buffer_usage; } - BufferUsageVector buffer_usage; + BufferUseVector buffer_usage; }; class FakeCmd : public CommandBufferCmd { @@ -91,7 +91,7 @@ class FakeCmd : public CommandBufferCmd { se::CommandBuffer* command_buffer) override { return absl::OkStatus(); } - BufferUsageVector buffers() override { return BufferUsageVector{}; } + BufferUseVector buffers() override { return BufferUseVector{}; } }; TEST(CommandBufferCmdTest, SerializeExecution) { @@ -101,13 +101,13 @@ TEST(CommandBufferCmdTest, SerializeExecution) { auto slice1 = BufferAllocation::Slice(&alloc0, 50, 100); // Reads from overlapping slices do not require barriers by default. - auto use0 = BufferUsage(slice0, MemoryAccess::kRead); - auto use1 = BufferUsage(slice1, MemoryAccess::kRead); + auto use0 = BufferUse(slice0, BufferUse::kRead); + auto use1 = BufferUse(slice1, BufferUse::kRead); CommandBufferCmdSequence commands( CommandBufferCmdSequence::SynchronizationMode::kSerialize); - commands.Emplace(s0, BufferUsageVector{use0}); - commands.Emplace(s0, BufferUsageVector{use1}); + commands.Emplace(s0, BufferUseVector{use0}); + commands.Emplace(s0, BufferUseVector{use1}); ASSERT_EQ(commands.barriers().size(), 2); EXPECT_EQ(commands.barriers().at(0), false); @@ -121,12 +121,12 @@ TEST(CommandBufferCmdTest, NoReadBarrier) { auto slice1 = BufferAllocation::Slice(&alloc0, 50, 100); // Reads from overlapping slices do not require barriers. - auto use0 = BufferUsage(slice0, MemoryAccess::kRead); - auto use1 = BufferUsage(slice1, MemoryAccess::kRead); + auto use0 = BufferUse(slice0, BufferUse::kRead); + auto use1 = BufferUse(slice1, BufferUse::kRead); CommandBufferCmdSequence commands; - commands.Emplace(s0, BufferUsageVector{use0}); - commands.Emplace(s0, BufferUsageVector{use1}); + commands.Emplace(s0, BufferUseVector{use0}); + commands.Emplace(s0, BufferUseVector{use1}); ASSERT_EQ(commands.barriers().size(), 2); EXPECT_EQ(commands.barriers().at(0), false); @@ -140,12 +140,12 @@ TEST(CommandBufferCmdTest, NoWriteBarrier) { auto slice0 = BufferAllocation::Slice(&alloc0, 0, 100); auto slice1 = BufferAllocation::Slice(&alloc0, 200, 100); - auto use0 = BufferUsage(slice0, MemoryAccess::kWrite); - auto use1 = BufferUsage(slice1, MemoryAccess::kWrite); + auto use0 = BufferUse(slice0, BufferUse::kWrite); + auto use1 = BufferUse(slice1, BufferUse::kWrite); CommandBufferCmdSequence commands; - commands.Emplace(s0, BufferUsageVector{use0}); - commands.Emplace(s0, BufferUsageVector{use1}); + commands.Emplace(s0, BufferUseVector{use0}); + commands.Emplace(s0, BufferUseVector{use1}); ASSERT_EQ(commands.barriers().size(), 2); EXPECT_EQ(commands.barriers().at(0), false); @@ -160,14 +160,14 @@ TEST(CommandBufferCmdTest, WriteConflictBarrier) { // Reads from overlapping slices can be done in parallel, and before a write // into overlapping slice we need to insert a barrier. - auto use0 = BufferUsage(slice0, MemoryAccess::kRead); - auto use1 = BufferUsage(slice0, MemoryAccess::kRead); - auto use2 = BufferUsage(slice1, MemoryAccess::kWrite); + auto use0 = BufferUse(slice0, BufferUse::kRead); + auto use1 = BufferUse(slice0, BufferUse::kRead); + auto use2 = BufferUse(slice1, BufferUse::kWrite); CommandBufferCmdSequence commands; - commands.Emplace(s0, BufferUsageVector{use0}); - commands.Emplace(s0, BufferUsageVector{use1}); - commands.Emplace(s0, BufferUsageVector{use2}); + commands.Emplace(s0, BufferUseVector{use0}); + commands.Emplace(s0, BufferUseVector{use1}); + commands.Emplace(s0, BufferUseVector{use2}); ASSERT_EQ(commands.barriers().size(), 3); EXPECT_EQ(commands.barriers().at(0), false); @@ -183,12 +183,12 @@ TEST(CommandBufferCmdTest, NoWriteConflictsAcrossStreams) { // Read and write happens on different execution streams and we do not insert // any automatic barriers between streams. - auto use0 = BufferUsage(slice0, MemoryAccess::kRead); - auto use1 = BufferUsage(slice1, MemoryAccess::kWrite); + auto use0 = BufferUse(slice0, BufferUse::kRead); + auto use1 = BufferUse(slice1, BufferUse::kWrite); CommandBufferCmdSequence commands; - commands.Emplace(s0, BufferUsageVector{use0}); - commands.Emplace(s1, BufferUsageVector{use1}); + commands.Emplace(s0, BufferUseVector{use0}); + commands.Emplace(s1, BufferUseVector{use1}); ASSERT_EQ(commands.barriers().size(), 2); EXPECT_EQ(commands.barriers().at(0), false); @@ -348,8 +348,7 @@ TEST(CommandBufferCmdTest, LaunchCmd) { BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); auto args = {slice_a, slice_a, slice_b}; // b = a + a - auto args_access = {MemoryAccess::kRead, MemoryAccess::kRead, - MemoryAccess::kWrite}; + auto args_access = {BufferUse::kRead, MemoryAccess::kRead, BufferUse::kWrite}; // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence commands; @@ -420,9 +419,9 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) { BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation alloc1(/*index=*/1, /*size=*/1024, /*color=*/0); - CommandBufferCmd::BufferUsageVector buffers = { - {BufferAllocation::Slice(&alloc0, 0, 1024), MemoryAccess::kRead}, - {BufferAllocation::Slice(&alloc1, 0, 1024), MemoryAccess::kWrite}}; + CommandBufferCmd::BufferUseVector buffers = { + {BufferAllocation::Slice(&alloc0, 0, 1024), BufferUse::kRead}, + {BufferAllocation::Slice(&alloc1, 0, 1024), BufferUse::kWrite}}; TracedCommandBuffer traced_cmd_buffer(&traced_cmd, buffers, /*capacity=*/trace_cache_size); @@ -510,9 +509,9 @@ static void BM_GetOrTraceCommandBuffer(benchmark::State& state) { BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation alloc1(/*index=*/1, /*size=*/1024, /*color=*/0); - CommandBufferCmd::BufferUsageVector buffers = { - {BufferAllocation::Slice(&alloc0, 0, 1024), MemoryAccess::kRead}, - {BufferAllocation::Slice(&alloc1, 0, 1024), MemoryAccess::kWrite}}; + CommandBufferCmd::BufferUseVector buffers = { + {BufferAllocation::Slice(&alloc0, 0, 1024), BufferUse::kRead}, + {BufferAllocation::Slice(&alloc1, 0, 1024), BufferUse::kWrite}}; se::DeviceMemoryBase mem0(reinterpret_cast(0x01234567)); se::DeviceMemoryBase mem1(reinterpret_cast(0x12345670)); diff --git a/xla/service/gpu/runtime/command_buffer_thunk_test.cc b/xla/service/gpu/runtime/command_buffer_thunk_test.cc index 1ca4b248b24a1..9d6147f448851 100644 --- a/xla/service/gpu/runtime/command_buffer_thunk_test.cc +++ b/xla/service/gpu/runtime/command_buffer_thunk_test.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/platform_util.h" #include "xla/service/service_executable_run_options.h" +#include "xla/runtime/buffer_use.h" #include "xla/shape_util.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" @@ -68,7 +69,7 @@ limitations under the License. namespace xla::gpu { -using MemoryAccess = CommandBufferCmd::MemoryAccess; +using MemoryAccess = BufferUse::MemoryAccess; using KernelArgsPacking = se::MultiKernelLoaderSpec::KernelArgsPacking; namespace {