diff --git a/xla/backends/gpu/collectives/BUILD b/xla/backends/gpu/collectives/BUILD
new file mode 100644
index 0000000000000..3bc94862e264c
--- /dev/null
+++ b/xla/backends/gpu/collectives/BUILD
@@ -0,0 +1,46 @@
+load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
+load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
+load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured")
+load("//xla/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = [":friends"],
+    licenses = ["notice"],
+)
+
+package_group(
+    name = "friends",
+    includes = [
+        "//xla:friends",
+    ],
+)
+
+cc_library(
+    name = "nccl_errors",
+    hdrs = if_gpu_is_configured(["nccl_errors.h"]),
+    visibility = ["//visibility:private"],
+    deps = [
+        "//xla:util",
+        "@com_google_absl//absl/strings:str_format",
+        "@tsl//tsl/platform:logging",
+    ],
+)
+
+cc_library(
+    name = "nccl_communicator",
+    srcs = if_gpu_is_configured(["nccl_communicator.cc"]),
+    hdrs = if_gpu_is_configured(["nccl_communicator.h"]),
+    visibility = ["//visibility:private"],
+    deps = [
+        ":nccl_errors",
+        "//xla/core/collectives:communicator",
+        "@com_google_absl//absl/strings:str_format",
+        "@tsl//tsl/platform:logging",
+    ] + if_cuda_is_configured([
+        "@local_config_nccl//:nccl",
+    ]) + if_rocm_is_configured([
+        "@local_config_rocm//rocm:rocm_headers",
+        "@local_config_rocm//rocm:rccl",
+    ]),
+)
diff --git a/xla/backends/gpu/collectives/nccl_communicator.cc b/xla/backends/gpu/collectives/nccl_communicator.cc
new file mode 100644
index 0000000000000..da2d6c23e9679
--- /dev/null
+++ b/xla/backends/gpu/collectives/nccl_communicator.cc
@@ -0,0 +1,48 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/gpu/collectives/nccl_communicator.h"
+
+#include <string>
+
+#include "absl/strings/str_format.h"
+#include "xla/backends/gpu/collectives/nccl_errors.h"
+#include "tsl/platform/logging.h"
+
+#if TENSORFLOW_USE_ROCM
+#include "rocm/rocm_config.h"
+#if (TF_ROCM_VERSION >= 50200)
+#include "rocm/include/rccl/rccl.h"
+#else
+#include "rocm/include/rccl.h"
+#endif  // TF_ROCM_VERSION >= 50200
+#else
+#include "third_party/nccl/nccl.h"
+#endif  // TENSORFLOW_USE_ROCM
+
+namespace xla::gpu {
+
+NcclCommunicator::NcclCommunicator(ncclComm_t comm) : comm_(comm) {}
+
+NcclCommunicator::~NcclCommunicator() {
+  VLOG(1) << "Destroy " << *this;
+  XLA_NCCL_LOG_IF_ERROR(ncclCommDestroy(comm_));
+}
+
+std::string NcclCommunicator::ToString() const {
+  return absl::StrFormat("NcclCommunicator(ncclComm_t=%p)", comm_);
+}
+
+}  // namespace xla::gpu
diff --git a/xla/backends/gpu/collectives/nccl_communicator.h b/xla/backends/gpu/collectives/nccl_communicator.h
new file mode 100644
index 0000000000000..aaf069a203fd2
--- /dev/null
+++ b/xla/backends/gpu/collectives/nccl_communicator.h
@@ -0,0 +1,52 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_BACKENDS_GPU_COLLECTIVES_NCCL_COMMUNICATOR_H_
+#define XLA_BACKENDS_GPU_COLLECTIVES_NCCL_COMMUNICATOR_H_
+
+#include <string>
+
+#include "xla/core/collectives/communicator.h"
+
+#if TENSORFLOW_USE_ROCM
+#include "rocm/rocm_config.h"
+#if (TF_ROCM_VERSION >= 50200)
+#include "rocm/include/rccl/rccl.h"
+#else
+#include "rocm/include/rccl.h"
+#endif  // TF_ROCM_VERSION >= 50200
+#else
+#include "third_party/nccl/nccl.h"
+#endif  // TENSORFLOW_USE_ROCM
+
+namespace xla::gpu {
+
+// XLA collectives communicator wrapping an NCCL communicator.
+class NcclCommunicator : public Communicator {
+ public:
+  explicit NcclCommunicator(ncclComm_t comm);
+  ~NcclCommunicator() override;
+
+  std::string ToString() const final;
+
+  ncclComm_t comm() const { return comm_; }
+
+ private:
+  ncclComm_t comm_;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_BACKENDS_GPU_COLLECTIVES_NCCL_COMMUNICATOR_H_
diff --git a/xla/backends/gpu/collectives/nccl_errors.h b/xla/backends/gpu/collectives/nccl_errors.h
new file mode 100644
index 0000000000000..61feee68cbdc3
--- /dev/null
+++ b/xla/backends/gpu/collectives/nccl_errors.h
@@ -0,0 +1,54 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_BACKENDS_GPU_COLLECTIVES_NCCL_ERRORS_H_
+#define XLA_BACKENDS_GPU_COLLECTIVES_NCCL_ERRORS_H_
+
+#include "absl/strings/str_format.h"  // IWYU pragma: keep
+#include "xla/util.h"  // IWYU pragma: keep
+#include "tsl/platform/logging.h"  // IWYU pragma: keep
+
+//===----------------------------------------------------------------------===//
+// Collection of helper macros for handling NCCL errors.
+//===----------------------------------------------------------------------===//
+
+#define XLA_NCCL_STATUS(expr)                                           \
+  [](ncclResult_t s, std::string_view str) -> absl::Status {            \
+    if (s == ncclSuccess) return absl::OkStatus();                      \
+    return xla::Internal(                                               \
+        "NCCL operation %s failed: %s. Last NCCL warning(error) log "   \
+        "entry (may be unrelated) '%s'.",                               \
+        str, ncclGetErrorString(s), ncclGetLastError(nullptr));         \
+  }(expr, #expr)
+
+#define XLA_NCCL_RETURN_IF_ERROR(expr)        \
+  do {                                        \
+    absl::Status s = XLA_NCCL_STATUS(expr);   \
+    if (!s.ok()) {                            \
+      return s;                               \
+    }                                         \
+  } while (0)
+
+#define XLA_NCCL_LOG_IF_ERROR(expr)           \
+  do {                                        \
+    absl::Status s = XLA_NCCL_STATUS(expr);   \
+    if (!s.ok()) {                            \
+      LOG(ERROR) << s.ToString();             \
+    }                                         \
+  } while (0)
+
+#define XLA_NCCL_CHECK(expr) CHECK(XLA_NCCL_STATUS(expr).ok())
+
+#endif  // XLA_BACKENDS_GPU_COLLECTIVES_NCCL_ERRORS_H_
diff --git a/xla/core/collectives/communicator.h b/xla/core/collectives/communicator.h
index fd500851b5aea..eb8d76b5f2267 100644
--- a/xla/core/collectives/communicator.h
+++ b/xla/core/collectives/communicator.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef XLA_CORE_COLLECTIVES_COMMUNICATOR_H_
 #define XLA_CORE_COLLECTIVES_COMMUNICATOR_H_
 
+#include <ostream>
 #include <string>
 
 namespace xla {
@@ -28,6 +29,10 @@ class Communicator {
   virtual std::string ToString() const = 0;
 };
 
+inline std::ostream& operator<<(std::ostream& os, const Communicator& comm) {
+  return os << comm.ToString();
+}
+
 }  // namespace xla
 
 #endif  // XLA_CORE_COLLECTIVES_COMMUNICATOR_H_
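
Usage note (not part of the diff): a minimal sketch of how the new wrapper and error macros are expected to compose. The helper name CreateNcclCommunicator and the surrounding clique-id bootstrap (ncclGetUniqueId plus an out-of-band exchange) are hypothetical; ncclCommInitRank, absl::StatusOr, and the types added above are the real APIs involved.

// Hypothetical helper, shown only to illustrate the intended usage of
// NcclCommunicator and XLA_NCCL_RETURN_IF_ERROR; it is not part of this diff.
#include <memory>

#include "absl/status/statusor.h"
#include "xla/backends/gpu/collectives/nccl_communicator.h"
#include "xla/backends/gpu/collectives/nccl_errors.h"

namespace xla::gpu {

absl::StatusOr<std::unique_ptr<NcclCommunicator>> CreateNcclCommunicator(
    const ncclUniqueId& clique_id, int nranks, int rank) {
  ncclComm_t comm = nullptr;
  // On failure, XLA_NCCL_RETURN_IF_ERROR converts the ncclResult_t into an
  // absl::Status and returns it from this function.
  XLA_NCCL_RETURN_IF_ERROR(ncclCommInitRank(&comm, nranks, clique_id, rank));
  // Ownership of `comm` passes to the wrapper; ncclCommDestroy runs in its
  // destructor, with failures logged (not propagated) via XLA_NCCL_LOG_IF_ERROR.
  return std::make_unique<NcclCommunicator>(comm);
}

}  // namespace xla::gpu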