Skip to content

Commit

Permalink
[xla:cpu] Migrate AllGather to unified collectives API
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 712111435
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Jan 4, 2025
1 parent 51e5004 commit ac6e71f
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 39 deletions.
1 change: 1 addition & 0 deletions xla/backends/cpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ cc_library(
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"//xla/service:collective_ops_utils",
Expand Down
7 changes: 5 additions & 2 deletions xla/backends/cpu/runtime/all_gather_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/service/buffer_assignment.h"
Expand Down Expand Up @@ -77,11 +78,13 @@ tsl::AsyncValueRef<AllGatherThunk::ExecuteEvent> AllGatherThunk::Execute(
return ExecuteWithCommunicator(
params.collective_params,
[&](const RendezvousKey& key, CollectivesCommunicator& comm) {
CpuCollectives::Executor executor(key, DefaultCollectiveTimeout());

for (int32_t i = 0; i < data.source.size(); ++i) {
const Shape& shape = source_shape(i);
TF_RETURN_IF_ERROR(comm.AllGather(
key, ShapeUtil::ByteSizeOf(shape), data.source[i].opaque(),
data.destination[i].opaque(), DefaultCollectiveTimeout()));
data.source[i], data.destination[i], shape.element_type(),
ShapeUtil::ElementsIn(shape), executor));
}
return absl::OkStatus();
});
Expand Down
18 changes: 9 additions & 9 deletions xla/pjrt/cpu/gloo_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,19 +290,19 @@ absl::Status GlooCollectivesCommunicator::AllToAll(
return absl::OkStatus();
}

absl::Status GlooCollectivesCommunicator::AllGather(const RendezvousKey& key,
size_t chunk_bytes,
const void* input_buffer,
void* output_buffer,
absl::Duration timeout) {
absl::Status GlooCollectivesCommunicator::AllGather(
se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count, const Executor& executor) {
uint32_t tag = 0; // TODO(phawkins): use better tags.

TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
size_t chunk_bytes = count * primitive_util::ByteWidth(dtype);

gloo::AllgatherOptions options(context_);
options.setTag(tag);
options.setTimeout(absl::ToChronoMilliseconds(timeout));
options.setInput(reinterpret_cast<char*>(const_cast<void*>(input_buffer)),
chunk_bytes);
options.setOutput(reinterpret_cast<char*>(output_buffer),
options.setTimeout(absl::ToChronoMilliseconds(cpu_executor->timeout()));
options.setInput(reinterpret_cast<char*>(send_buffer.opaque()), chunk_bytes);
options.setOutput(reinterpret_cast<char*>(recv_buffer.opaque()),
chunk_bytes * context_->size);

try {
Expand Down
6 changes: 3 additions & 3 deletions xla/pjrt/cpu/gloo_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ class GlooCollectivesCommunicator : public CollectivesCommunicator {
absl::Span<const void* const> input_buffers,
absl::Span<void* const> output_buffers,
absl::Duration timeout) override;
absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes,
const void* input_buffer, void* output_buffer,
absl::Duration timeout) override;
absl::Status AllGather(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
size_t count, const Executor& executor) override;
absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
Expand Down
15 changes: 7 additions & 8 deletions xla/pjrt/cpu/mpi_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,13 @@ absl::Status MpiCollectivesCommunicator::AllToAll(
return absl::OkStatus();
}

absl::Status MpiCollectivesCommunicator::AllGather(const RendezvousKey& key,
size_t chunk_bytes,
const void* input_buffer,
void* output_buffer,
absl::Duration timeout) {
return MpiErrorToAbslStatus(MPI_Allgather(input_buffer, chunk_bytes, MPI_BYTE,
output_buffer, chunk_bytes,
MPI_BYTE, comm_));
absl::Status MpiCollectivesCommunicator::AllGather(
se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count, const Executor& executor) {
TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype));
return MpiErrorToAbslStatus(MPI_Allgather(send_buffer.opaque(), count, type,
recv_buffer.opaque(), count, type,
comm_));
}

absl::Status MpiCollectivesCommunicator::ReduceScatter(
Expand Down
6 changes: 3 additions & 3 deletions xla/pjrt/cpu/mpi_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ class MpiCollectivesCommunicator : public CollectivesCommunicator {
absl::Span<const void* const> input_buffers,
absl::Span<void* const> output_buffers,
absl::Duration timeout) override;
absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes,
const void* input_buffer, void* output_buffer,
absl::Duration timeout) override;
absl::Status AllGather(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
size_t count, const Executor& executor) override;
absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
Expand Down
7 changes: 4 additions & 3 deletions xla/service/cpu/collectives_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,10 @@ class CollectivesCommunicator {
absl::Duration timeout) = 0;

// Performs an all-gather.
virtual absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes,
const void* input_buffer, void* output_buffer,
absl::Duration timeout) = 0;
virtual absl::Status AllGather(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
const Executor& executor) = 0;

// Performs a reduce-scatter
virtual absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
Expand Down
10 changes: 7 additions & 3 deletions xla/service/cpu/cpu_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -421,9 +421,13 @@ void AllGatherImpl(const ExecutableRunOptions* run_options,

auto communicator =
collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();
TF_CHECK_OK(communicator->AllGather(rendezvous_key, buffer_size,
source_buffer, destination_buffer,
DefaultCollectiveTimeout()));

se::DeviceMemoryBase input_buffer_data(source_buffer, buffer_size);
se::DeviceMemoryBase output_buffer_data(destination_buffer, buffer_size);

CpuCollectives::Executor executor(rendezvous_key, DefaultCollectiveTimeout());
TF_CHECK_OK(communicator->AllGather(input_buffer_data, output_buffer_data, U8,
buffer_size, executor));
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY
Expand Down
13 changes: 8 additions & 5 deletions xla/service/cpu/in_process_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,12 +514,15 @@ absl::Status InProcessCollectivesCommunicator::AllToAll(
}

absl::Status InProcessCollectivesCommunicator::AllGather(
const RendezvousKey& key, size_t chunk_bytes, const void* input_buffer,
void* output_buffer, absl::Duration timeout) {
se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count, const Executor& executor) {
TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
const RendezvousKey& key = cpu_executor->rendezvous_key();

AllGatherParticipantData participant(key, rank_);
participant.chunk_size = chunk_bytes;
participant.source_buffer = input_buffer;
participant.destination_buffer = output_buffer;
participant.chunk_size = count * primitive_util::ByteWidth(dtype);
participant.source_buffer = send_buffer.opaque();
participant.destination_buffer = recv_buffer.opaque();
auto make_cpu_rendezvous = [](const RendezvousKey& k) {
return std::make_unique<CpuAllGatherRendezvous>(k);
};
Expand Down
6 changes: 3 additions & 3 deletions xla/service/cpu/in_process_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ class InProcessCollectivesCommunicator : public CollectivesCommunicator {
absl::Span<void* const> output_buffers,
absl::Duration timeout) override;

absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes,
const void* input_buffer, void* output_buffer,
absl::Duration timeout) override;
absl::Status AllGather(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
size_t count, const Executor& executor) override;

absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
Expand Down

0 comments on commit ac6e71f

Please sign in to comment.