diff --git a/xla/service/gpu/gpu_norm_runner.cc b/xla/service/gpu/gpu_norm_runner.cc index 5abb58af4e019..9170de6016e60 100644 --- a/xla/service/gpu/gpu_norm_runner.cc +++ b/xla/service/gpu/gpu_norm_runner.cc @@ -44,22 +44,8 @@ absl::Status RunGpuNorm(const gpu::GpuNormConfig& config, se::Stream* stream, RunNormOptions options) { se::dnn::LazyOpRunner* lazy_runner = options.norm_runner->AsNormRunner(); - std::optional> local_runner; - - TF_ASSIGN_OR_RETURN(se::dnn::NormKind kind, - GetDNNNormKindFromCudnnNormKind(config.kind)); - - se::dnn::NormOp::Config ln_config{kind, - config.epsilon, - config.x_descriptor, - config.scale_descriptor, - config.y_or_dx_descriptor, - config.bias_descriptor, - config.dy_descriptor, - config.expectation_descriptor, - config.norm_factor_descriptor, - config.dscale_descriptor, - config.dbias_descriptor}; + TF_ASSIGN_OR_RETURN(se::dnn::NormOp::Config ln_config, + config.AsDnnNormOpConfig()); TF_ASSIGN_OR_RETURN(auto* runner, lazy_runner->GetOrCreateRunner(ln_config, stream)); diff --git a/xla/service/gpu/gpu_norm_runner.h b/xla/service/gpu/gpu_norm_runner.h index 854e3c0892050..8461671e86d03 100644 --- a/xla/service/gpu/gpu_norm_runner.h +++ b/xla/service/gpu/gpu_norm_runner.h @@ -118,6 +118,22 @@ struct GpuNormConfig { return config; } + absl::StatusOr AsDnnNormOpConfig() const { + TF_ASSIGN_OR_RETURN(se::dnn::NormKind norm_kind, + GetDNNNormKindFromCudnnNormKind(kind)); + return se::dnn::NormOp::Config{norm_kind, + epsilon, + x_descriptor, + scale_descriptor, + y_or_dx_descriptor, + bias_descriptor, + dy_descriptor, + expectation_descriptor, + norm_factor_descriptor, + dscale_descriptor, + dbias_descriptor}; + } + double epsilon; CudnnNormKind kind; se::dnn::AlgorithmDesc algorithm; diff --git a/xla/service/gpu/runtime/norm_thunk.cc b/xla/service/gpu/runtime/norm_thunk.cc index d3862f7bfeac7..71c0744686e40 100644 --- a/xla/service/gpu/runtime/norm_thunk.cc +++ b/xla/service/gpu/runtime/norm_thunk.cc @@ -106,5 +106,14 @@ absl::Status NormThunk::ExecuteOnStream(const ExecuteParams& params) { return absl::OkStatus(); } +absl::Status NormThunk::Initialize(const InitializeParams& params) { + // Create the runner at initialization time to avoid hangs if we try to build + // the execution plan while a NCCL collective is running. + se::dnn::LazyOpRunner* lazy_runner = + GetOrCreateRunner(params.stream).AsNormRunner(); + TF_ASSIGN_OR_RETURN(auto ln_config, config_.AsDnnNormOpConfig()); + return lazy_runner->GetOrCreateRunner(ln_config, params.stream).status(); +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/runtime/norm_thunk.h b/xla/service/gpu/runtime/norm_thunk.h index 602d504175fb3..eca5deca3a68b 100644 --- a/xla/service/gpu/runtime/norm_thunk.h +++ b/xla/service/gpu/runtime/norm_thunk.h @@ -49,6 +49,7 @@ class NormThunk : public Thunk { NormThunk& operator=(const NormThunk&) = delete; absl::Status ExecuteOnStream(const ExecuteParams& params) override; + absl::Status Initialize(const InitializeParams& params) override; private: BufferAllocation::Slice x_buffer_;