Skip to content

Commit

Permalink
[xla:collectives] Add backends/gpu/collectives:nccl_communicator
Browse files Browse the repository at this point in the history
NCCL implementation detail will have private visibility, and for all external users (Thunks etc.) we'll export it via public header that uses xla/core/collectives APIs.

PiperOrigin-RevId: 699256314
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Nov 22, 2024
1 parent 218dfb0 commit 9cb47c3
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 0 deletions.
46 changes: 46 additions & 0 deletions xla/backends/gpu/collectives/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured")
load("//xla/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [":friends"],
licenses = ["notice"],
)

package_group(
name = "friends",
includes = [
"//xla:friends",
],
)

cc_library(
name = "nccl_errors",
hdrs = if_gpu_is_configured(["nccl_errors.h"]),
visibility = ["//visibility:private"],
deps = [
"//xla:util",
"@com_google_absl//absl/strings:str_format",
"@tsl//tsl/platform:logging",
],
)

cc_library(
name = "nccl_communicator",
srcs = if_gpu_is_configured(["nccl_communicator.cc"]),
hdrs = if_gpu_is_configured(["nccl_communicator.h"]),
visibility = ["//visibility:private"],
deps = [
":nccl_errors",
"//xla/core/collectives:communicator",
"@com_google_absl//absl/strings:str_format",
"@tsl//tsl/platform:logging",
] + if_cuda_is_configured([
"@local_config_nccl//:nccl",
]) + if_rocm_is_configured([
"@local_config_rocm//rocm:rocm_headers",
"@local_config_rocm//rocm:rccl",
]),
)
48 changes: 48 additions & 0 deletions xla/backends/gpu/collectives/nccl_communicator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/gpu/collectives/nccl_communicator.h"

#include <string>

#include "absl/strings/str_format.h"
#include "xla/backends/gpu/collectives/nccl_errors.h"
#include "tsl/platform/logging.h"

#if TENSORFLOW_USE_ROCM
#include "rocm/rocm_config.h"
#if (TF_ROCM_VERSION >= 50200)
#include "rocm/include/rccl/rccl.h"
#else
#include "rocm/include/rccl.h"
#endif // TF_ROCM_VERSION >= 50200
#else
#include "third_party/nccl/nccl.h"
#endif // TENSORFLOW_USE_ROCM

namespace xla::gpu {

NcclCommunicator::NcclCommunicator(ncclComm_t comm) : comm_(comm) {}

NcclCommunicator::~NcclCommunicator() {
VLOG(1) << "Destroy " << *this;
XLA_NCCL_LOG_IF_ERROR(ncclCommDestroy(comm_));
}

std::string NcclCommunicator::ToString() const {
return absl::StrFormat("NccCommunicator(ncclComm_t=%p)", comm_);
}

} // namespace xla::gpu
52 changes: 52 additions & 0 deletions xla/backends/gpu/collectives/nccl_communicator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_GPU_COLLECTIVES_NCCL_COMMUNICATOR_H_
#define XLA_BACKENDS_GPU_COLLECTIVES_NCCL_COMMUNICATOR_H_

#include <string>

#include "xla/core/collectives/communicator.h"

#if TENSORFLOW_USE_ROCM
#include "rocm/rocm_config.h"
#if (TF_ROCM_VERSION >= 50200)
#include "rocm/include/rccl/rccl.h"
#else
#include "rocm/include/rccl.h"
#endif // TF_ROCM_VERSION >= 50200
#else
#include "third_party/nccl/nccl.h"
#endif // TENSORFLOW_USE_ROCM

namespace xla::gpu {

// XLA collectives communicator wrapping an NCCL communicator.
class NcclCommunicator : public Communicator {
public:
explicit NcclCommunicator(ncclComm_t comm);
~NcclCommunicator() override;

std::string ToString() const final;

ncclComm_t comm() const { return comm_; }

private:
ncclComm_t comm_;
};

} // namespace xla::gpu

#endif // XLA_BACKENDS_GPU_COLLECTIVES_NCCL_COMMUNICATOR_H_
54 changes: 54 additions & 0 deletions xla/backends/gpu/collectives/nccl_errors.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_GPU_COLLECTIVES_NCCL_ERRORS_H_
#define XLA_BACKENDS_GPU_COLLECTIVES_NCCL_ERRORS_H_

#include "absl/strings/str_format.h" // IWYU pragma: keep
#include "xla/util.h" // IWYU pragma: keep
#include "tsl/platform/logging.h" // IWYU pragma: keep

//===----------------------------------------------------------------------===//
// Collection of helper macros for handling NCCL errors.
//===----------------------------------------------------------------------===//

#define XLA_NCCL_STATUS(expr) \
[](ncclResult_t s, std::string_view str) -> absl::Status { \
if (s == ncclSuccess) return absl::OkStatus(); \
return xla::Internal( \
"NCCL operation %s failed: %s. Last NCCL warning(error) log " \
"entry (may be unrelated) '%s'.", \
str, ncclGetErrorString(s), ncclGetLastError(nullptr)); \
}(expr, #expr)

#define XLA_NCCL_RETURN_IF_ERROR(expr) \
do { \
absl::Status s = XLA_NCCL_STATUS(expr); \
if (!s.ok()) { \
return s; \
} \
} while (0)

#define XLA_NCCL_LOG_IF_ERROR(expr) \
do { \
absl::Status s = XLA_NCCL_STATUS(expr); \
if (!s.ok()) { \
LOG(ERROR) << s.ToString(); \
} \
} while (0)

#define XLA_NCCL_CHECK(expr) CHECK(XLA_NCCL_STATUS(expr).ok())

#endif // XLA_BACKENDS_GPU_COLLECTIVES_NCCL_ERRORS_H_
5 changes: 5 additions & 0 deletions xla/core/collectives/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef XLA_CORE_COLLECTIVES_COMMUNICATOR_H_
#define XLA_CORE_COLLECTIVES_COMMUNICATOR_H_

#include <ostream>
#include <string>

namespace xla {
Expand All @@ -28,6 +29,10 @@ class Communicator {
virtual std::string ToString() const = 0;
};

inline std::ostream& operator<<(std::ostream& os, const Communicator& comm) {
return os << comm.ToString();
}

} // namespace xla

#endif // XLA_CORE_COLLECTIVES_COMMUNICATOR_H_

0 comments on commit 9cb47c3

Please sign in to comment.