Skip to content

Commit

Permalink
[xla:cpu] Migrate ReduceScatter to unified collectives API
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712095466
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Jan 4, 2025
1 parent e94e29f commit a2316c3
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 62 deletions.
1 change: 1 addition & 0 deletions xla/backends/cpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ cc_library(
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"//xla/service:collective_ops_utils",
Expand Down
9 changes: 6 additions & 3 deletions xla/backends/cpu/runtime/reduce_scatter_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/primitive_util.h"
Expand Down Expand Up @@ -90,13 +91,15 @@ ReduceScatterThunk::Execute(const ExecuteParams& params) {
return ExecuteWithCommunicator(
params.collective_params,
[&](const RendezvousKey& key, CollectivesCommunicator& comm) {
CpuCollectives::Executor executor(key, DefaultCollectiveTimeout());

for (int32_t i = 0; i < data.source.size(); ++i) {
const Shape& shape = destination_shape(i);
TF_RETURN_IF_ERROR(comm.ReduceScatter(
key, reduction_kind_, shape.element_type(),
ShapeUtil::ElementsIn(shape), data.source[i].opaque(),
data.destination[i].opaque(), DefaultCollectiveTimeout()));
data.source[i], data.destination[i], shape.element_type(),
ShapeUtil::ElementsIn(shape), reduction_kind_, executor));
}

return absl::OkStatus();
});
}
Expand Down
1 change: 0 additions & 1 deletion xla/core/collectives/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ class Communicator {
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
ReductionKind reduction_kind,

const Executor& executor) = 0;

// Gather `count` values from all devices into `recv_buffer`, receiving data
Expand Down
50 changes: 25 additions & 25 deletions xla/pjrt/cpu/gloo_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -364,74 +364,74 @@ absl::Status ReduceScatterHelper(std::shared_ptr<gloo::Context> context,
}

absl::Status GlooCollectivesCommunicator::ReduceScatter(
const RendezvousKey& key, ReductionKind reduction_kind,
PrimitiveType element_type, size_t chunk_elems, const void* input_buffer,
void* output_buffer, absl::Duration timeout) {
size_t chunk_bytes = chunk_elems * primitive_util::ByteWidth(element_type);
se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count, ReductionKind reduction_kind,
const Executor& executor) {
size_t chunk_bytes = count * primitive_util::ByteWidth(dtype);
std::unique_ptr<char[]> temp(new char[chunk_bytes * context_->size]);
std::memcpy(temp.get(), input_buffer, chunk_bytes * context_->size);
switch (element_type) {
std::memcpy(temp.get(), send_buffer.opaque(), chunk_bytes * context_->size);
switch (dtype) {
case S8:
TF_RETURN_IF_ERROR(ReduceScatterHelper<int8_t>(context_, reduction_kind,
temp.get(), chunk_elems));
temp.get(), count));
break;
case PRED:
case U8:
TF_RETURN_IF_ERROR(ReduceScatterHelper<uint8_t>(context_, reduction_kind,
temp.get(), chunk_elems));
temp.get(), count));
break;
case S16:
TF_RETURN_IF_ERROR(ReduceScatterHelper<int16_t>(context_, reduction_kind,
temp.get(), chunk_elems));
temp.get(), count));
break;
case U16:
TF_RETURN_IF_ERROR(ReduceScatterHelper<uint16_t>(
context_, reduction_kind, temp.get(), chunk_elems));
TF_RETURN_IF_ERROR(ReduceScatterHelper<uint16_t>(context_, reduction_kind,
temp.get(), count));
break;
case S32:
TF_RETURN_IF_ERROR(ReduceScatterHelper<int32_t>(context_, reduction_kind,
temp.get(), chunk_elems));
temp.get(), count));
break;
case U32:
TF_RETURN_IF_ERROR(ReduceScatterHelper<uint32_t>(
context_, reduction_kind, temp.get(), chunk_elems));
TF_RETURN_IF_ERROR(ReduceScatterHelper<uint32_t>(context_, reduction_kind,
temp.get(), count));
break;
case S64:
TF_RETURN_IF_ERROR(ReduceScatterHelper<int64_t>(context_, reduction_kind,
temp.get(), chunk_elems));
temp.get(), count));
break;
case U64:
TF_RETURN_IF_ERROR(ReduceScatterHelper<uint64_t>(
context_, reduction_kind, temp.get(), chunk_elems));
TF_RETURN_IF_ERROR(ReduceScatterHelper<uint64_t>(context_, reduction_kind,
temp.get(), count));
break;
case BF16:
TF_RETURN_IF_ERROR(ReduceScatterHelper<bfloat16>(
context_, reduction_kind, temp.get(), chunk_elems));
TF_RETURN_IF_ERROR(ReduceScatterHelper<bfloat16>(context_, reduction_kind,
temp.get(), count));
break;
case F16:
TF_RETURN_IF_ERROR(ReduceScatterHelper<gloo::float16>(
context_, reduction_kind, temp.get(), chunk_elems));
context_, reduction_kind, temp.get(), count));
break;
case F32:
TF_RETURN_IF_ERROR(ReduceScatterHelper<float>(context_, reduction_kind,
temp.get(), chunk_elems));
temp.get(), count));
break;
case F64:
TF_RETURN_IF_ERROR(ReduceScatterHelper<double>(context_, reduction_kind,
temp.get(), chunk_elems));
temp.get(), count));
break;
case C64:
TF_RETURN_IF_ERROR(ReduceScatterHelper<std::complex<float>>(
context_, reduction_kind, temp.get(), chunk_elems));
context_, reduction_kind, temp.get(), count));
break;
case C128:
TF_RETURN_IF_ERROR(ReduceScatterHelper<std::complex<double>>(
context_, reduction_kind, temp.get(), chunk_elems));
context_, reduction_kind, temp.get(), count));
break;
default:
return absl::InvalidArgumentError("Unknown datatype in reducescatter");
}
std::memcpy(output_buffer, temp.get(), chunk_bytes);
std::memcpy(recv_buffer.opaque(), temp.get(), chunk_bytes);
return absl::OkStatus();
}

Expand Down
8 changes: 4 additions & 4 deletions xla/pjrt/cpu/gloo_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ class GlooCollectivesCommunicator : public CollectivesCommunicator {
absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes,
const void* input_buffer, void* output_buffer,
absl::Duration timeout) override;
absl::Status ReduceScatter(const RendezvousKey& key,
absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
ReductionKind reduction_kind,
PrimitiveType element_type, size_t chunk_elems,
const void* input_buffer, void* output_buffer,
absl::Duration timeout) override;
const Executor& executor) override;

private:
std::shared_ptr<gloo::Context> context_;
Expand Down
15 changes: 8 additions & 7 deletions xla/pjrt/cpu/mpi_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,16 @@ absl::Status MpiCollectivesCommunicator::AllGather(const RendezvousKey& key,
}

absl::Status MpiCollectivesCommunicator::ReduceScatter(
const RendezvousKey& key, ReductionKind reduction_kind,
PrimitiveType element_type, size_t chunk_elems, const void* input_buffer,
void* output_buffer, absl::Duration timeout) {
se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count, ReductionKind reduction_kind,
const Executor& executor) {
const int size = mpi_size_;
std::vector<int> recvcounts(size, chunk_elems);
TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(element_type));
std::vector<int> recvcounts(size, count);
TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype));
TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type));
return MpiErrorToAbslStatus(MPI_Reduce_scatter(
input_buffer, output_buffer, recvcounts.data(), type, op, comm_));
return MpiErrorToAbslStatus(
MPI_Reduce_scatter(send_buffer.opaque(), recv_buffer.opaque(),
recvcounts.data(), type, op, comm_));
}

void MpiCollectives::Init() {
Expand Down
8 changes: 4 additions & 4 deletions xla/pjrt/cpu/mpi_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ class MpiCollectivesCommunicator : public CollectivesCommunicator {
absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes,
const void* input_buffer, void* output_buffer,
absl::Duration timeout) override;
absl::Status ReduceScatter(const RendezvousKey& key,
absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
ReductionKind reduction_kind,
PrimitiveType element_type, size_t chunk_elems,
const void* input_buffer, void* output_buffer,
absl::Duration timeout) override;
const Executor& executor) override;

private:
MPI_Comm comm_;
Expand Down
9 changes: 5 additions & 4 deletions xla/service/cpu/collectives_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,11 @@ class CollectivesCommunicator {
absl::Duration timeout) = 0;

// Performs a reduce-scatter
virtual absl::Status ReduceScatter(
const RendezvousKey& key, ReductionKind reduction_kind,
PrimitiveType element_type, size_t chunk_elems, const void* input_buffer,
void* output_buffer, absl::Duration timeout) = 0;
virtual absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
ReductionKind reduction_kind,
const Executor& executor) = 0;
};

class CollectivesInterface {
Expand Down
15 changes: 12 additions & 3 deletions xla/service/cpu/cpu_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ limitations under the License.
#include "xla/executable_run_options.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/layout_util.h"
#include "xla/primitive_util.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/computation_placer.h"
#include "xla/service/cpu/collectives_interface.h"
Expand Down Expand Up @@ -449,10 +450,18 @@ void ReduceScatterImpl(const ExecutableRunOptions* run_options,

auto communicator =
collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();

auto dtype = static_cast<PrimitiveType>(element_type);

se::DeviceMemoryBase input_buffer_data(input_buffer,
primitive_util::ByteWidth(dtype));
se::DeviceMemoryBase output_buffer_data(output_buffer,
primitive_util::ByteWidth(dtype));

CpuCollectives::Executor executor(rendezvous_key, DefaultCollectiveTimeout());
TF_CHECK_OK(communicator->ReduceScatter(
rendezvous_key, static_cast<ReductionKind>(reduction_kind),
static_cast<PrimitiveType>(element_type), chunk_elems, input_buffer,
output_buffer, DefaultCollectiveTimeout()));
input_buffer_data, output_buffer_data, dtype, chunk_elems,
static_cast<ReductionKind>(reduction_kind), executor));
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY
Expand Down
17 changes: 10 additions & 7 deletions xla/service/cpu/in_process_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -533,15 +533,18 @@ absl::Status InProcessCollectivesCommunicator::AllGather(
}

absl::Status InProcessCollectivesCommunicator::ReduceScatter(
const RendezvousKey& key, ReductionKind reduction_kind,
PrimitiveType element_type, size_t chunk_elems, const void* input_buffer,
void* output_buffer, absl::Duration timeout) {
se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count, ReductionKind reduction_kind,
const Executor& executor) {
TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
const RendezvousKey& key = cpu_executor->rendezvous_key();

ReduceScatterParticipantData participant(key, rank_);
participant.element_type = element_type;
participant.element_type = dtype;
participant.reduction_kind = reduction_kind;
participant.chunk_elems = chunk_elems;
participant.source_buffer = input_buffer;
participant.destination_buffer = output_buffer;
participant.chunk_elems = count;
participant.source_buffer = send_buffer.opaque();
participant.destination_buffer = recv_buffer.opaque();
auto make_cpu_rendezvous = [](const RendezvousKey& k) {
return std::make_unique<CpuReduceScatterRendezvous>(k);
};
Expand Down
8 changes: 4 additions & 4 deletions xla/service/cpu/in_process_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ class InProcessCollectivesCommunicator : public CollectivesCommunicator {
const void* input_buffer, void* output_buffer,
absl::Duration timeout) override;

absl::Status ReduceScatter(const RendezvousKey& key,
absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
ReductionKind reduction_kind,
PrimitiveType element_type, size_t chunk_elems,
const void* input_buffer, void* output_buffer,
absl::Duration timeout) override;
const Executor& executor) override;

private:
InProcessCollectivesState* state_;
Expand Down

0 comments on commit a2316c3

Please sign in to comment.