From bbfcc6ffb2006c48ac0eac6d374150832351c42e Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 5 Oct 2023 16:46:07 -0700 Subject: [PATCH] Improve CUDA EP's GetCapability --- .../providers/cuda/cuda_execution_provider.cc | 23 ++++++++++++------- .../providers/shared_library/provider_api.h | 3 +++ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index ad892eab3b843..de01e240a06c7 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -373,7 +373,7 @@ Status CUDAExecutionProvider::OnRunStart() { // always set CUDA device when session::Run() in case it runs in a worker thread CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId())); if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { - LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; + LOGS(*GetLogger(), INFO) << "Capturing the cuda graph for this model"; GetPerThreadContext().CaptureBegin(); } return Status::OK(); @@ -2410,7 +2410,7 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node, return false; } -static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) { +static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger) { const auto& node_attributes = node.GetAttributes(); // Check attributes for (auto& attr : node_attributes) { @@ -2428,7 +2428,7 @@ static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) { int rank = pads_size / 2; for (int i = 0; i < rank; i++) { if (pads.Get(i) != pads.Get(i + rank)) { - LOGS_DEFAULT(WARNING) << "Dropping the ConvTranspose node: " << node.Name() + LOGS(logger, WARNING) << "Dropping the ConvTranspose node: " << node.Name() << " to CPU because it requires asymmetric padding which the CUDA EP" << " currently does not support"; return true; @@ -2450,7 +2450,7 @@ static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) { // symmetric padding. // TODO: Remove this after we have supported asymmetric padding in the CUDA ConvTranspose kernel if (auto_pad_attr == "SAME_UPPER" || auto_pad_attr == "SAME_LOWER") { - LOGS_DEFAULT(WARNING) << "Dropping the ConvTranspose node: " << node.Name() + LOGS(logger, WARNING) << "Dropping the ConvTranspose node: " << node.Name() << " to CPU because it uses the auto_pad attribute which may lead to asymmetric padding which" << " the CUDA EP currently does not support"; return true; @@ -2487,6 +2487,9 @@ std::vector> CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup) const { InlinedVector candidates; + // A subset of the above vector. A subset of the tentative_nodes might be moved to CPU. + InlinedVector tentative_nodes; + const logging::Logger& logger = *GetLogger(); for (auto& node_index : graph.GetNodesInTopologicalOrder()) { const auto* p_node = graph.GetNode(node_index); if (p_node == nullptr) @@ -2494,13 +2497,16 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const auto& node = *p_node; if (!node.GetExecutionProviderType().empty()) { + if (node.GetExecutionProviderType() == kCudaExecutionProvider) { + candidates.push_back(node.Index()); + } continue; } const KernelCreateInfo* cuda_kernel_def = kernel_lookup.LookUpKernel(node); // none of the provided registries has a CUDA kernel for this node if (cuda_kernel_def == nullptr) { - LOGS_DEFAULT(INFO) << "CUDA kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name(); + LOGS(logger, INFO) << "CUDA kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name(); continue; } @@ -2520,7 +2526,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, not_supported = RNNNeedFallbackToCPU(node, activations_supported, node.OpType()); force_inside = !not_supported; } else if ("ConvTranspose" == node.OpType()) { - not_supported = ConvTransposeNeedFallbackToCPU(node); + not_supported = ConvTransposeNeedFallbackToCPU(node, logger); force_inside = !not_supported; } else if ("Cast" == node.OpType()) { not_supported = CastNeedFallbackToCPU(node); @@ -2529,9 +2535,10 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, if (!force_inside && not_supported) { if (not_supported) { - LOGS_DEFAULT(WARNING) << "CUDA kernel not supported. Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name(); + LOGS(logger, WARNING) << "CUDA kernel not supported. Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name(); } } else { + tentative_nodes.push_back(node.Index()); candidates.push_back(node.Index()); } } @@ -2539,7 +2546,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // For CUDA EP, exclude the subgraph that is preferred to be placed in CPU // These are usually shape related computation subgraphs // Following logic can be extended for other EPs - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, candidates); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 0d7da46142170..85599fab808b3 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -350,6 +350,9 @@ void InitProviderOrtApi(); if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, ::onnxruntime::logging::DataType::SYSTEM)) \ CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::SYSTEM)->Stream() +#define LOGS(logger, severity) \ + LOGS_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime) + #define LOGS_DEFAULT_CATEGORY(severity, category) \ LOGS_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category)