Fixed some issues around compiling on Windows. #15444

Closed
24 changes: 12 additions & 12 deletions xla/backends/profiler/gpu/cupti_buffer_events.cc
@@ -186,18 +186,18 @@ void AddGraphTraceActivityEvent(CuptiEventCollectorDelegate &collector,
AnnotationMap::AnnotationInfo info = collector.annotation_map.LookUp(
graph_trace->deviceId, graph_trace->correlationId);
collector.receive(CuptiTracerEvent{
.type = CuptiTracerEventType::CudaGraph,
.source = CuptiTracerEventSource::Activity,
.name = absl::StrCat("CudaGraphExec:", graph_trace->graphId),
.annotation = info.annotation,
.nvtx_range = info.nvtx_range,
.start_time_ns = graph_trace->start,
.end_time_ns = graph_trace->end,
.device_id = graph_trace->deviceId,
.correlation_id = graph_trace->correlationId,
.context_id = graph_trace->contextId,
.stream_id = graph_trace->streamId,
.graph_id = graph_trace->graphId,
/* .type = */ CuptiTracerEventType::CudaGraph,
Contributor Author

XLA is configured to build with C++17, but designated initializers are a C++20 feature, so this fails to compile on Windows with:

error C7555: use of designated initializers requires at least '/std:c++20'
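
For reference, a minimal sketch of the pattern this hunk applies (hypothetical Event struct and values, not the actual CuptiTracerEvent definition):

#include <string>

constexpr int kGraph = 1;

struct Event {
  int type;
  std::string name;
};

// C++20 designated initializers; MSVC rejects this under /std:c++17 (C7555):
//   Event e{.type = kGraph, .name = "CudaGraphExec:1"};

// C++17-compatible positional aggregate initialization, with the field names
// kept as comments so the mapping stays readable:
Event e{/* .type = */ kGraph,
        /* .name = */ "CudaGraphExec:1"};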

/* .source = */ CuptiTracerEventSource::Activity,
/* .name = */ absl::StrCat("CudaGraphExec:", graph_trace->graphId),
/* .annotation = */ info.annotation,
/* .nvtx_range = */ info.nvtx_range,
/* .start_time_ns = */ graph_trace->start,
/* .end_time_ns = */ graph_trace->end,
/* .device_id = */ graph_trace->deviceId,
/* .correlation_id = */ graph_trace->correlationId,
/* .context_id = */ graph_trace->contextId,
/* .stream_id = */ graph_trace->streamId,
/* .graph_id = */ graph_trace->graphId,
});
}

2 changes: 1 addition & 1 deletion xla/backends/profiler/gpu/cupti_buffer_events.h
@@ -56,7 +56,7 @@ struct MemcpyDetails {
int8_t dst_mem_kind;

// ID of the hardware channel on which this operation ran.
uint32_t channel_id = -1;
uint32_t channel_id = static_cast<uint32_t>(-1);
Contributor Author

This resulted in an implicit narrowing-conversion error (I believe it was C2397). The explicit static_cast fixes it.
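
For reference, a standalone illustration of the rule (hypothetical variables, not the actual header). The narrowing check bites in brace-initialization contexts, and the cast makes the intended wrap-around explicit:

#include <cstdint>

std::uint32_t a = -1;   // implicit signed-to-unsigned conversion: accepted
// std::uint32_t b{-1}; // brace initialization: narrowing, ill-formed
std::uint32_t c = static_cast<std::uint32_t>(-1);  // explicit: 0xFFFFFFFF everywhere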

// CUpti_ChannelType of the channel above.
int8_t channel_type = 0; // CUPTI_CHANNEL_TYPE_INVALID
};
2 changes: 2 additions & 0 deletions xla/hlo/evaluator/hlo_evaluator.cc
@@ -535,7 +535,9 @@ std::optional<DynamicOrStaticInteger> EvaluateWhileLoopParamInitValue(

namespace internal {

#if !defined(_MSC_VER)
constexpr absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl";
#endif

std::optional<EvalErrorDetail> ParseEvalErrorDetail(const absl::Status& error) {
auto error_detail = error.GetPayload(kEvalErrorDetailUrl);
4 changes: 4 additions & 0 deletions xla/hlo/evaluator/hlo_evaluator.h
@@ -530,7 +530,11 @@ enum class EvalErrorDetail : uint32_t {
kDynamicValueDependence = 0,
};

#if defined(_MSC_VER)
extern const absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl";
#else
extern const absl::string_view kEvalErrorDetailUrl;
#endif

std::optional<EvalErrorDetail> ParseEvalErrorDetail(const absl::Status& error);

2 changes: 1 addition & 1 deletion xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
@@ -2129,7 +2129,7 @@ PJRT_Error* PJRT_Layouts_MemoryLayout_Serialize(
PJRT_Layouts_MemoryLayout_Serialize_Args_STRUCT_SIZE, args->struct_size));

PJRT_Layouts_SerializedLayout* s_layout = new PJRT_Layouts_SerializedLayout{
.serialized = args->layout->layout->Serialize()};
/* .serialized = */ args->layout->layout->Serialize()};
args->serialized_layout = s_layout;
args->serialized_bytes = s_layout->serialized.data();
args->serialized_bytes_size = s_layout->serialized.size();
14 changes: 8 additions & 6 deletions xla/pjrt/gpu/se_gpu_pjrt_compiler.cc
@@ -199,13 +199,15 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options,
#endif
}

STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
PjRtRegisterCompiler(
#if TENSORFLOW_USE_ROCM
Contributor Author

MSVC does not support preprocessor directives inside a macro's argument list (the standard makes them undefined behavior there). Hoisting the inner #if outside the macro invocation works and does not change the behavior of the code here.
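
A minimal sketch of the pattern, with a hypothetical REGISTER macro standing in for STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER:

#include <cstdio>

#define REGISTER(fn) void Run() { fn(); }

void Rocm() { std::puts("rocm"); }
void Cuda() { std::puts("cuda"); }

// Not portable: a preprocessor directive inside a macro argument list is
// undefined behavior per the standard, and MSVC rejects it:
//
//   REGISTER(
//   #if TENSORFLOW_USE_ROCM
//       Rocm
//   #else
//       Cuda
//   #endif
//   )

// Portable: duplicate the invocation and select between them with the #if.
#if defined(TENSORFLOW_USE_ROCM)
REGISTER(Rocm)
#else
REGISTER(Cuda)
#endif

int main() { Run(); }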

RocmName(),
#else
CudaName(),
#endif
std::make_unique<StreamExecutorGpuCompiler>());
});
#if TENSORFLOW_USE_ROCM
STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
PjRtRegisterCompiler(RocmName(),
std::make_unique<StreamExecutorGpuCompiler>());
});
#else
STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
PjRtRegisterCompiler(CudaName(),
std::make_unique<StreamExecutorGpuCompiler>());
});
#endif
} // namespace xla
8 changes: 4 additions & 4 deletions xla/service/cpu/runtime/conv_impl.h
@@ -41,7 +41,7 @@ void EigenConv2DImpl(
Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation,
Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation,
Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count,
std::optional<std::function<void()>> done_callback = std::nullopt) {
Contributor Author (@eaplatanios, Jul 29, 2024)

This resulted in the following error:

error C2765: 'tensorflow::xla::EigenConv2DImpl': an explicit specialization or instantiation of a function template cannot have any default arguments

I removed the couple of default arguments that were causing this error and propagated the default (std::nullopt) to the call sites that had been relying on it.
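
A minimal repro of the rule behind C2765, using a hypothetical Conv in place of the extern-template plumbing that the CONV2D_EXTERN_TEMPLATE macro generates here (the PR also drops the default from the primary template, which is why call sites gained an explicit std::nullopt):

#include <optional>

template <typename T>
void Conv(T x, std::optional<int> done) {
  // ... convolution body elided ...
}

// Ill-formed on MSVC (C2765): an explicit instantiation that carries a
// default argument:
//
//   extern template void Conv<float>(float, std::optional<int> = std::nullopt);

// Well-formed: the instantiation lists the parameter without a default.
extern template void Conv<float>(float, std::optional<int>);
template void Conv<float>(float, std::optional<int>);

int main() {
  Conv(1.0f, std::nullopt);  // call sites now spell the argument out
}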

std::optional<std::function<void()>> done_callback) {
const Eigen::TensorMap<Eigen::Tensor<const ScalarType, 4, Eigen::RowMajor>,
Eigen::Aligned>
input(lhs, input_batch, input_x, input_y, input_channels);
@@ -129,7 +129,7 @@ void EigenConv3DImpl(
Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation,
Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation,
Eigen::Index feature_group_count,
std::optional<std::function<void()>> done_callback = std::nullopt) {
std::optional<std::function<void()>> done_callback) {
using ConstTType =
Eigen::TensorMap<Eigen::Tensor<const ScalarType, 5, Eigen::RowMajor>,
Eigen::Aligned>;
@@ -223,7 +223,7 @@ void EigenConv3DImpl(
Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, \
Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, \
Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, \
std::optional<std::function<void()>> done_callback = std::nullopt)
std::optional<std::function<void()>> done_callback)

CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half);
CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float);
@@ -249,7 +249,7 @@ CONV2D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, float);
Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, \
Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, \
Eigen::Index feature_group_count, \
std::optional<std::function<void()>> done_callback = std::nullopt)
std::optional<std::function<void()>> done_callback)

CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half);
CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float);
6 changes: 4 additions & 2 deletions xla/service/cpu/runtime_conv2d.cc
@@ -15,6 +15,8 @@ limitations under the License.

#include "xla/service/cpu/runtime_conv2d.h"

#include <optional>

#define EIGEN_USE_THREADS

#include "absl/base/dynamic_annotations.h"
@@ -41,7 +43,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF32(
kernel_channels, kernel_filters, output_rows, output_cols, row_stride,
col_stride, padding_top, padding_bottom, padding_left, padding_right,
lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation,
feature_group_count);
feature_group_count, std::nullopt);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF16(
@@ -63,5 +65,5 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF16(
kernel_channels, kernel_filters, output_rows, output_cols, row_stride,
col_stride, padding_top, padding_bottom, padding_left, padding_right,
lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation,
feature_group_count);
feature_group_count, std::nullopt);
}
6 changes: 4 additions & 2 deletions xla/service/cpu/runtime_conv3d.cc
@@ -15,6 +15,8 @@ limitations under the License.

#include "xla/service/cpu/runtime_conv3d.h"

#include <optional>

#define EIGEN_USE_THREADS

#include "absl/base/dynamic_annotations.h"
@@ -44,7 +46,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF32(
y_stride, z_stride, padding_x_before, padding_x_after, padding_y_before,
padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation,
lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation,
rhs_z_dilation, feature_group_count);
rhs_z_dilation, feature_group_count, std::nullopt);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF16(
@@ -69,5 +71,5 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF16(
y_stride, z_stride, padding_x_before, padding_x_after, padding_y_before,
padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation,
lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation,
rhs_z_dilation, feature_group_count);
rhs_z_dilation, feature_group_count, std::nullopt);
}
6 changes: 4 additions & 2 deletions xla/service/cpu/runtime_single_threaded_conv2d.cc
@@ -15,6 +15,8 @@ limitations under the License.

#include "xla/service/cpu/runtime_single_threaded_conv2d.h"

#include <optional>

#include "absl/base/dynamic_annotations.h"
#include "xla/service/cpu/runtime/conv_impl.h"

@@ -35,7 +37,7 @@ __xla_cpu_runtime_EigenSingleThreadedConv2DF16(
kernel_filters, output_rows, output_cols, row_stride, col_stride,
padding_top, padding_bottom, padding_left, padding_right,
lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation,
feature_group_count);
feature_group_count, std::nullopt);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
@@ -55,5 +57,5 @@ __xla_cpu_runtime_EigenSingleThreadedConv2DF32(
kernel_filters, output_rows, output_cols, row_stride, col_stride,
padding_top, padding_bottom, padding_left, padding_right,
lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation,
feature_group_count);
feature_group_count, std::nullopt);
}
6 changes: 4 additions & 2 deletions xla/service/cpu/runtime_single_threaded_conv3d.cc
@@ -15,6 +15,8 @@ limitations under the License.

#include "xla/service/cpu/runtime_single_threaded_conv3d.h"

#include <optional>

#include "absl/base/dynamic_annotations.h"
#include "xla/service/cpu/runtime/conv_impl.h"

@@ -38,7 +40,7 @@ __xla_cpu_runtime_EigenSingleThreadedConv3DF32(
z_stride, padding_x_before, padding_x_after, padding_y_before,
padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation,
lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation,
rhs_z_dilation, feature_group_count);
rhs_z_dilation, feature_group_count, std::nullopt);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
@@ -61,5 +63,5 @@ __xla_cpu_runtime_EigenSingleThreadedConv3DF16(
z_stride, padding_x_before, padding_x_after, padding_y_before,
padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation,
lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation,
rhs_z_dilation, feature_group_count);
rhs_z_dilation, feature_group_count, std::nullopt);
}
12 changes: 6 additions & 6 deletions xla/service/gpu/fusions/mlir/computation_partitioner.cc
@@ -300,12 +300,12 @@ PartitionedComputation::PartitionedComputation(
absl::StrJoin(roots, "_", [](std::string* out, const auto* root) {
absl::StrAppend(out, root->name());
})));
subgraphs_.push_back(
Subgraph{.name = std::move(name),
.instructions = {instructions.begin(), instructions.end()},
.roots = std::move(roots),
.index_ranges = std::move(ranges),
.root_indexing = std::move(root_indexing)});
subgraphs_.push_back(Subgraph{
/* .name = */ std::move(name),
/* .instructions = */ {instructions.begin(), instructions.end()},
/* .roots = */ std::move(roots),
/* .index_ranges = */ std::move(ranges),
/* .root_indexing = */ std::move(root_indexing)});
}

for (const auto& subgraph : subgraphs_) {
3 changes: 1 addition & 2 deletions xla/service/gpu/fusions/mlir/erase_dead_functions.cc
@@ -77,8 +77,7 @@ class EraseDeadFunctionsPass

} // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateEraseDeadFunctionsPass() {
std::unique_ptr<mlir::Pass> CreateEraseDeadFunctionsPass() {
Contributor Author

I pushed this after the approval to resolve a linking error. The header declares this function as returning std::unique_ptr<mlir::Pass>, and MSVC encodes the return type in its decorated symbol names, so the mismatched return type produced a symbol that failed to link.
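
A minimal sketch of the failure mode, with hypothetical Base/Derived standing in for mlir::Pass and mlir::OperationPass:

#include <memory>

struct Base { virtual ~Base() = default; };
struct Derived : Base {};

// The header declares:
std::unique_ptr<Base> Create();

// If the .cc instead defines
//
//   std::unique_ptr<Derived> Create() { return std::make_unique<Derived>(); }
//
// GCC/Clang mangle Create() without its return type (_Z6Createv), so both
// spellings collapse to the same symbol and the mismatch can go unnoticed;
// MSVC decorates the return type into the symbol name, so the caller's
// reference to the declared signature never resolves at link time.

// The fix applied here: define the function with the exact declared type.
std::unique_ptr<Base> Create() { return std::make_unique<Derived>(); }

int main() { auto p = Create(); }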

return std::make_unique<EraseDeadFunctionsPass>();
}

3 changes: 1 addition & 2 deletions xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc
@@ -108,8 +108,7 @@ void MergePointersToSameSlicePass::runOnOperation() {

} // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateMergePointersToSameSlicePass() {
std::unique_ptr<mlir::Pass> CreateMergePointersToSameSlicePass() {
return std::make_unique<MergePointersToSameSlicePass>();
}

3 changes: 1 addition & 2 deletions xla/service/gpu/fusions/mlir/optimize_loops.cc
@@ -306,8 +306,7 @@ class OptimizeLoopsPass

} // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateOptimizeLoopsPass() {
std::unique_ptr<mlir::Pass> CreateOptimizeLoopsPass() {
return std::make_unique<OptimizeLoopsPass>();
}

3 changes: 1 addition & 2 deletions xla/service/gpu/fusions/mlir/unswitch_loops.cc
@@ -97,8 +97,7 @@ void UnswitchLoopsPass::runOnOperation() {

} // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateUnswitchLoopsPass() {
std::unique_ptr<mlir::Pass> CreateUnswitchLoopsPass() {
return std::make_unique<UnswitchLoopsPass>();
}

3 changes: 1 addition & 2 deletions xla/service/gpu/fusions/mlir/vectorize_loads_stores.cc
@@ -350,8 +350,7 @@ class VectorizeLoadsAndStoresPass

} // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateVectorizeLoadsAndStoresPass() {
std::unique_ptr<mlir::Pass> CreateVectorizeLoadsAndStoresPass() {
return std::make_unique<VectorizeLoadsAndStoresPass>();
}

25 changes: 20 additions & 5 deletions xla/service/gpu/kernels/BUILD
@@ -325,7 +325,10 @@ cc_library(
cuda_library(
name = "cutlass_gemm_adaptor",
hdrs = if_cuda_is_configured(["cutlass_gemm_adaptor.cu.h"]),
copts = ["-Wno-unknown-attributes"], # __grid_constant__ is not supported by clang
copts = select({
"@xla//xla/tsl:windows": [],
"//conditions:default": ["-Wno-unknown-attributes"], # __grid_constant__ is not supported by clang
}),
deps = if_cuda_is_configured([
":cutlass_gemm",
"@cutlass_archive//:cutlass",
@@ -367,7 +370,10 @@ cc_library(
cuda_library(
name = "cutlass_gemm_kernel_bf16xbf16_to_bf16",
srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16.cu.cc"]),
copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"],
copts = ["-mllvm", "-unroll-threshold=100000"] + select({
"@xla//xla/tsl:windows": [],
"//conditions:default": ["-Wno-unknown-attributes"],
}),
deps = if_cuda_is_configured([
":cutlass_gemm_adaptor",
"@cutlass_archive//:cutlass",
@@ -378,7 +384,10 @@ cuda_library(
cuda_library(
name = "cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80",
srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80.cu.cc"]),
copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"],
copts = ["-mllvm", "-unroll-threshold=100000"] + select({
"@xla//xla/tsl:windows": [],
"//conditions:default": ["-Wno-unknown-attributes"],
}),
deps = if_cuda_is_configured([
":cutlass_gemm_adaptor",
"@cutlass_archive//:cutlass",
@@ -389,7 +398,10 @@ cuda_library(
cuda_library(
name = "cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90",
srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90.cu.cc"]),
copts = ["-Wno-ctad-maybe-unsupported -Wno-unknown-attributes -mllvm -unroll-threshold=100000"],
copts = ["-mllvm", "-unroll-threshold=100000"] + select({
"@xla//xla/tsl:windows": [],
"//conditions:default": ["-Wno-ctad-maybe-unsupported", "-Wno-unknown-attributes"],
}),
deps = if_cuda_is_configured([
":cutlass_gemm_adaptor",
":cutlass_gemm_epilogue",
@@ -401,7 +413,10 @@ cuda_library(
cuda_library(
name = "cutlass_gemm_kernel_f32xf32_to_f32",
srcs = if_cuda_is_configured(["cutlass_gemm_kernel_f32xf32_to_f32.cu.cc"]),
copts = ["-Wno-unknown-attributes"],
copts = select({
"@xla//xla/tsl:windows": [],
"//conditions:default": ["-Wno-unknown-attributes"],
}),
deps = if_cuda_is_configured([
":cutlass_gemm_adaptor",
"@cutlass_archive//:cutlass",
5 changes: 3 additions & 2 deletions xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h
@@ -199,8 +199,9 @@ namespace adaptor_3x {
template <typename Tag>
static std::optional<Dim3> ClusterDim() {
typename Traits<Tag>::Kernel::DispatchPolicy::ClusterShape cluster;
return Dim3{cute::get<0>(cluster), cute::get<1>(cluster),
cute::get<2>(cluster)};
return Dim3{static_cast<uint32_t>(cute::get<0>(cluster)),
static_cast<uint32_t>(cute::get<1>(cluster)),
static_cast<uint32_t>(cute::get<2>(cluster))};
}

template <typename Tag>
4 changes: 4 additions & 0 deletions xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc
@@ -101,7 +101,11 @@ KernelArgsPacking ArgsPacking(int32_t m, int32_t n, int32_t k,
// object constructed in the storage. For now we ignore it, and it's textbook
// definition of UB, but for CUTLASS kernels we use today it's perfectly safe.
struct Params {
#if defined(_MSC_VER)
alignas(64) std::byte storage[1024];
#else
alignas(128) std::byte storage[1024];
#endif
Contributor Author

This results in the following error on Windows:

error C2719: '<args_0>': formal parameter with requested alignment of 128 won't be aligned

cc @dimvar who previously made the change from 64 to 128.
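
A minimal sketch of the constraint (hypothetical names; the '<args_0>' in the diagnostic suggests the struct is passed by value somewhere in the argument-packing path). MSVC refuses by-value parameters whose requested alignment it cannot guarantee in its calling convention; empirically, alignas(128) trips C2719 here while alignas(64) is accepted:

#include <cstddef>

#if defined(_MSC_VER)
struct Params { alignas(64) std::byte storage[1024]; };
#else
struct Params { alignas(128) std::byte storage[1024]; };
#endif

// With alignas(128) on MSVC, a by-value parameter like this raises C2719:
void Pack(Params p) { /* ... */ }

int main() {
  Pack(Params{});
}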

};

return [=](const se::Kernel& kernel, const se::KernelArgs& args) -> Packed {
2 changes: 1 addition & 1 deletion xla/service/gpu/model/gpu_collective_performance_model.cc
@@ -133,7 +133,7 @@ float GpuPerformanceWithCollectiveModel::GetNvlinkBw(
}

/*static*/ bool GpuPerformanceWithCollectiveModel::InitNvml() {
#if GOOGLE_CUDA
#if GOOGLE_CUDA && defined(PLATFORM_POSIX)
void* libhandle = dlopen("libnvidia-ml.so.1", RTLD_NOW);
CHECK(libhandle != nullptr) << "Failed to open libnvidia-ml.so.1";
