Skip to content

Commit

Permalink
[xla:collectives] Migrate Broadcast to type-safe RankId to identify b…
Browse files Browse the repository at this point in the history
…roadcast root

PiperOrigin-RevId: 712110760
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Jan 4, 2025
1 parent a2316c3 commit 51e5004
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 7 deletions.
6 changes: 3 additions & 3 deletions xla/backends/gpu/collectives/nccl_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ absl::Status NcclCommunicator::AllReduce(
absl::Status NcclCommunicator::Broadcast(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
size_t root,
RankId root,
const Executor& executor) {
TF_ASSIGN_OR_RETURN(se::Stream * stream, ToStream(executor));

Expand All @@ -241,13 +241,13 @@ absl::Status NcclCommunicator::Broadcast(se::DeviceMemoryBase send_buffer,
"stream=%p",
stream->parent()->device_ordinal(), send_buffer.opaque(),
recv_buffer.opaque(), primitive_util::LowercasePrimitiveTypeName(dtype),
count, root, comm_, stream);
count, root.value(), comm_, stream);

TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false));

return XLA_NCCL_STATUS(ncclBroadcast(
send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count),
nccl_dtype, root, comm_, se::gpu::AsGpuStreamValue(stream)));
nccl_dtype, root.value(), comm_, se::gpu::AsGpuStreamValue(stream)));
}

absl::Status NcclCommunicator::ReduceScatter(se::DeviceMemoryBase send_buffer,
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/gpu/collectives/nccl_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class NcclCommunicator : public Communicator {

absl::Status Broadcast(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
size_t count, size_t root,
size_t count, RankId root,
const Executor& executor) final;

absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
Expand Down
2 changes: 1 addition & 1 deletion xla/core/collectives/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class Communicator {
// all other devices.
virtual absl::Status Broadcast(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count, size_t root,
PrimitiveType dtype, size_t count, RankId root,
const Executor& executor) = 0;

// Reduce data in `send_buff` from all devices using the `reduction_kind`
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/backends/gpu/collectives:gpu_collectives",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/hlo/ir:hlo",
"//xla/service:collective_ops_utils",
"//xla/stream_executor:device_memory",
Expand Down
5 changes: 3 additions & 2 deletions xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "xla/backends/gpu/collectives/gpu_collectives.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/collective_ops_utils.h"
Expand Down Expand Up @@ -77,8 +78,8 @@ absl::Status RunCollectiveBroadcast(std::vector<DeviceBufferPair>& buffers,
TF_RETURN_IF_ERROR(comm->Broadcast(
// Always use rank 0 since we always broadcast from the first id in
// replica_groups
src_addr, dest_addr, buffer.element_type, buffer.element_count, 0,
GpuCollectives::On(stream)));
src_addr, dest_addr, buffer.element_type, buffer.element_count,
RankId(0), GpuCollectives::On(stream)));
}
return collectives->GroupEnd();
}
Expand Down

0 comments on commit 51e5004

Please sign in to comment.