Skip to content

Commit

Permalink
[xla:cpu] Migrate AllReduce to unified collectives API
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 711530846
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Jan 2, 2025
1 parent ac5a809 commit 8c345fa
Show file tree
Hide file tree
Showing 16 changed files with 197 additions and 88 deletions.
2 changes: 2 additions & 0 deletions xla/backends/cpu/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ cc_library(
"//xla/core/collectives",
"//xla/core/collectives:collectives_registry",
"//xla/core/collectives:communicator",
"//xla/service:collective_ops_utils",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/time",
"@tsl//tsl/platform:casts",
],
)
24 changes: 24 additions & 0 deletions xla/backends/cpu/collectives/cpu_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ limitations under the License.
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "xla/core/collectives/collectives.h"
#include "xla/core/collectives/collectives_registry.h"
#include "xla/core/collectives/communicator.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/util.h"
#include "tsl/platform/casts.h"

namespace xla::cpu {
Expand All @@ -36,4 +40,24 @@ CpuCollectives* CpuCollectives::Default() {
LOG(FATAL) << "Unsupported collectives implementation for CPU";
}

absl::StatusOr<const CpuCollectives::Device*> CpuCollectives::TryCast(
const Collectives::Device* device) {
if (auto* cpu_device = tsl::down_cast<const Device*>(device)) {
return cpu_device;
}
return InvalidArgument("Collectives device is not a CPU device");
}

absl::StatusOr<const CpuCollectives::Executor*> CpuCollectives::TryCast(
const Communicator::Executor* executor) {
if (auto* cpu_executor = tsl::down_cast<const Executor*>(executor)) {
return cpu_executor;
}
return InvalidArgument("Collectives executor is not a CPU executor");
}

CpuCollectives::Executor::Executor(RendezvousKey rendezvous_key,
absl::Duration timeout)
: rendezvous_key_(rendezvous_key), timeout_(timeout) {}

} // namespace xla::cpu
22 changes: 21 additions & 1 deletion xla/backends/cpu/collectives/cpu_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ limitations under the License.
#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_COLLECTIVES_H_
#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_COLLECTIVES_H_

#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "xla/core/collectives/collectives.h"
#include "xla/core/collectives/communicator.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/xla_data.pb.h"

namespace xla::cpu {
Expand All @@ -33,10 +36,27 @@ class CpuCollectives : public Collectives {
Device() = default;
};

// Executor allows CPU collectives clients to pass additional information to
// the collectives implementation.
class Executor : public Communicator::Executor {
public:
Executor() = default;
Executor(RendezvousKey rendezvous_key, absl::Duration timeout);

const RendezvousKey& rendezvous_key() const { return rendezvous_key_; }
const absl::Duration& timeout() const { return timeout_; }

private:
RendezvousKey rendezvous_key_;
absl::Duration timeout_;
};

// Tries to cast a Collectives::Device to a CpuCollectives::Device.
static absl::StatusOr<const Device*> TryCast(
const Collectives::Device* device);

// Tries to cast a Communicator::Executor to a CpuCollectives::Executor.
static absl::StatusOr<const Executor*> TryCast(
const Communicator::Executor* executor);
};

} // namespace xla::cpu
Expand Down
4 changes: 4 additions & 0 deletions xla/backends/cpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -458,13 +458,17 @@ cc_library(
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"//xla/service:collective_ops_utils",
"//xla/service/cpu:collectives_interface",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/platform:errors",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down
13 changes: 7 additions & 6 deletions xla/backends/cpu/runtime/all_reduce_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ limitations under the License.
#include <utility>

#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/primitive_util.h"
Expand All @@ -35,9 +37,8 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/errors.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

Expand Down Expand Up @@ -102,12 +103,12 @@ tsl::AsyncValueRef<AllReduceThunk::ExecuteEvent> AllReduceThunk::Execute(
return ExecuteWithCommunicator(
params.collective_params,
[&](const RendezvousKey& key, CollectivesCommunicator& comm) {
CpuCollectives::Executor executor(key, DefaultCollectiveTimeout());
for (int32_t i = 0; i < data.source.size(); ++i) {
const Shape& shape = destination_shape(i);
TF_RETURN_IF_ERROR(comm.AllReduce(
key, reduction_kind_, shape.element_type(),
ShapeUtil::ElementsIn(shape), data.source[i].opaque(),
data.destination[i].opaque(), DefaultCollectiveTimeout()));
data.source[i], data.destination[i], shape.element_type(),
ShapeUtil::ElementsIn(shape), reduction_kind_, executor));
}
return absl::OkStatus();
});
Expand Down
18 changes: 12 additions & 6 deletions xla/pjrt/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,13 @@ cc_library(
"//xla:status_macros",
"//xla:types",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/service:collective_ops_utils",
"//xla/service:global_device_id",
"//xla/service/cpu:collectives_interface",
"//xla/stream_executor:device_memory",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
Expand All @@ -325,21 +329,23 @@ xla_cc_test(
":gloo_kv_store",
"//xla:executable_run_options",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/pjrt/distributed:in_memory_key_value_store",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:collective_ops_utils",
"//xla/service:global_device_id",
"//xla/service/cpu:collectives_interface",
"//xla/stream_executor:device_memory",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"//xla/tsl/platform:test_benchmark",
"//xla/tsl/platform:test_main",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_benchmark",
"@tsl//tsl/platform:test_main",
] + select({
# Gloo's transport_tcp is not available on MacOS
"//xla/tsl:macos": [
Expand Down
60 changes: 33 additions & 27 deletions xla/pjrt/cpu/gloo_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,17 @@ limitations under the License.
#include "gloo/transport/device.h"
#include "gloo/transport/unbound_buffer.h"
#include "gloo/types.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/primitive_util.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/service/global_device_id.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/types.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"

namespace xla::cpu {

Expand All @@ -66,14 +68,16 @@ GlooCollectivesCommunicator::~GlooCollectivesCommunicator() = default;

template <typename T>
static absl::Status SetAllReduceOptions(ReductionKind reduction_kind,
const void* input_buffer,
void* output_buffer,
se::DeviceMemoryBase input_buffer,
se::DeviceMemoryBase output_buffer,
size_t num_elements,
gloo::AllreduceOptions& options) {
options.setInput(reinterpret_cast<T*>(const_cast<void*>(input_buffer)),
num_elements);
options.setOutput(reinterpret_cast<T*>(const_cast<void*>(output_buffer)),
num_elements);
options.setInput(
reinterpret_cast<T*>(const_cast<void*>(input_buffer.opaque())),
num_elements);
options.setOutput(
reinterpret_cast<T*>(const_cast<void*>(output_buffer.opaque())),
num_elements);

using ReductionFn = void (*)(void*, const void*, const void*, size_t);

Expand Down Expand Up @@ -105,75 +109,77 @@ static absl::Status SetAllReduceOptions(ReductionKind reduction_kind,
}

absl::Status GlooCollectivesCommunicator::AllReduce(
const RendezvousKey& key, ReductionKind reduction_kind,
PrimitiveType element_type, size_t num_elements, const void* input_buffer,
void* output_buffer, absl::Duration timeout) {
se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count, ReductionKind reduction_kind,
const Executor& executor) {
TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));

gloo::AllreduceOptions options(context_);
// TODO(phawkins): how to do tags?
// options.setTag(tag);
switch (element_type) {
switch (dtype) {
case S8:
TF_RETURN_IF_ERROR(SetAllReduceOptions<int8_t>(
reduction_kind, input_buffer, output_buffer, num_elements, options));
reduction_kind, send_buffer, recv_buffer, count, options));
break;
case PRED:
case U8:
TF_RETURN_IF_ERROR(SetAllReduceOptions<uint8_t>(
reduction_kind, input_buffer, output_buffer, num_elements, options));
reduction_kind, send_buffer, recv_buffer, count, options));
break;
case S16:
TF_RETURN_IF_ERROR(SetAllReduceOptions<int16_t>(
reduction_kind, input_buffer, output_buffer, num_elements, options));
reduction_kind, send_buffer, recv_buffer, count, options));
break;
case U16:
TF_RETURN_IF_ERROR(SetAllReduceOptions<uint16_t>(
reduction_kind, input_buffer, output_buffer, num_elements, options));
reduction_kind, send_buffer, recv_buffer, count, options));
break;
case S32:
TF_RETURN_IF_ERROR(SetAllReduceOptions<int32_t>(
reduction_kind, input_buffer, output_buffer, num_elements, options));
reduction_kind, send_buffer, recv_buffer, count, options));
break;
case U32:
TF_RETURN_IF_ERROR(SetAllReduceOptions<uint32_t>(
reduction_kind, input_buffer, output_buffer, num_elements, options));
reduction_kind, send_buffer, recv_buffer, count, options));
break;
case S64:
TF_RETURN_IF_ERROR(SetAllReduceOptions<int64_t>(
reduction_kind, input_buffer, output_buffer, num_elements, options));
reduction_kind, send_buffer, recv_buffer, count, options));
break;
case U64:
TF_RETURN_IF_ERROR(SetAllReduceOptions<uint64_t>(
reduction_kind, input_buffer, output_buffer, num_elements, options));
reduction_kind, send_buffer, recv_buffer, count, options));
break;
case F16:
TF_RETURN_IF_ERROR(SetAllReduceOptions<gloo::float16>(
reduction_kind, input_buffer, output_buffer, num_elements, options));
reduction_kind, send_buffer, recv_buffer, count, options));
break;
case BF16:
TF_RETURN_IF_ERROR(SetAllReduceOptions<bfloat16>(
reduction_kind, input_buffer, output_buffer, num_elements, options));
reduction_kind, send_buffer, recv_buffer, count, options));
break;
case F32:
TF_RETURN_IF_ERROR(SetAllReduceOptions<float>(
reduction_kind, input_buffer, output_buffer, num_elements, options));
reduction_kind, send_buffer, recv_buffer, count, options));
break;
case F64:
TF_RETURN_IF_ERROR(SetAllReduceOptions<double>(
reduction_kind, input_buffer, output_buffer, num_elements, options));
reduction_kind, send_buffer, recv_buffer, count, options));
break;
case C64:
TF_RETURN_IF_ERROR(SetAllReduceOptions<std::complex<float>>(
reduction_kind, input_buffer, output_buffer, num_elements, options));
reduction_kind, send_buffer, recv_buffer, count, options));
break;
case C128:
TF_RETURN_IF_ERROR(SetAllReduceOptions<std::complex<double>>(
reduction_kind, input_buffer, output_buffer, num_elements, options));
reduction_kind, send_buffer, recv_buffer, count, options));
break;
default:
return absl::InvalidArgumentError("Unknown datatype in allreduce");
}
options.setAlgorithm(gloo::AllreduceOptions::Algorithm::RING);
options.setTimeout(absl::ToChronoMilliseconds(timeout));
options.setTimeout(absl::ToChronoMilliseconds(cpu_executor->timeout()));

try {
gloo::allreduce(options);
Expand Down
8 changes: 4 additions & 4 deletions xla/pjrt/cpu/gloo_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ class GlooCollectivesCommunicator : public CollectivesCommunicator {
explicit GlooCollectivesCommunicator(std::shared_ptr<gloo::Context> context);
~GlooCollectivesCommunicator() override;

absl::Status AllReduce(const RendezvousKey& key, ReductionKind reduction_kind,
PrimitiveType element_type, size_t num_elements,
const void* input_buffer, void* output_buffer,
absl::Duration timeout) override;
absl::Status AllReduce(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
size_t count, ReductionKind reduction_kind,
const Executor& executor) override;
absl::Status CollectivePermute(const RendezvousKey& key, size_t num_bytes,
std::optional<int> source_rank,
absl::Span<int const> target_ranks,
Expand Down
Loading

0 comments on commit 8c345fa

Please sign in to comment.