From f125b3db73b02a6e67587b08fe60465ae7c9e3a5 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Sat, 7 Dec 2024 09:58:15 +0800 Subject: [PATCH] [WebNN] Allow ops to handle ignoring an empty tensor as input (#22972) ### Description Some ops should allow empty tensor as input, e.g. roi, scales inputs in Resize ### Motivation and Context It avoid some unexpected fallback for optional input with empty tensor. e.g. roi and scales are both optional inputs in Resize, in some models they have non-empty name but with empty initializer presented as `[0]`, WebNN currently will fallback all nodes with 0 dimension, which is not expected. ![image](https://github.com/user-attachments/assets/599ba351-b5f6-49ac-8a1f-69fb28dbaf9b) --- onnxruntime/core/providers/webnn/builders/helper.cc | 5 +++-- onnxruntime/core/providers/webnn/builders/helper.h | 3 ++- .../core/providers/webnn/builders/impl/base_op_builder.cc | 2 +- .../core/providers/webnn/builders/impl/base_op_builder.h | 5 +++++ .../core/providers/webnn/builders/impl/resize_op_builder.cc | 2 ++ 5 files changed, 13 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 537e552af2763..f36f8283e9bf6 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -69,7 +69,8 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We } } -bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) { +bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, + const logging::Logger& logger, bool allow_empty_input) { const auto& node_arg_name = node_arg.Name(); const auto* shape_proto = node_arg.Shape(); // Optional tensors can be indicated by an empty name, just ignore it. @@ -89,7 +90,7 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n << "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name; return false; } - if (dim.dim_value() == 0) { + if (dim.dim_value() == 0 && !allow_empty_input) { LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN"; return false; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 23489f142df3e..7fdfc5aefa798 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -181,7 +181,8 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; }); } -bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); +bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, + const logging::Logger& logger, bool allow_empty_input = false); // Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index 704896e43a7e1..70fa0f9516c5c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -45,7 +45,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { - if (!IsTensorShapeSupported(*input, node_name, logger)) { + if (!IsTensorShapeSupported(*input, node_name, logger, allow_empty_tensor_as_input_)) { return false; } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index a632876dab2b9..9412fa8026fb3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -22,6 +22,9 @@ class BaseOpBuilder : public IOpBuilder { const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; protected: + explicit BaseOpBuilder(bool allow_empty_tensor_as_input = false) + : allow_empty_tensor_as_input_(allow_empty_tensor_as_input) { + } virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const ORT_MUST_USE_RESULT = 0; @@ -55,6 +58,8 @@ class BaseOpBuilder : public IOpBuilder { bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + + const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input. }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 408045571e422..00f8cff25ccf5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -21,6 +21,8 @@ namespace webnn { class ResizeOpBuilder : public BaseOpBuilder { // Add operator related. public: + // Allow roi and scales potentially being empty inputs that are ignored during processing. + ResizeOpBuilder() : BaseOpBuilder(/*allow empty inputs*/ true) {} void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; private: