Skip to content

Commit

Permalink
[WebNN] Allow ops to handle ignoring an empty tensor as input (micros…
Browse files Browse the repository at this point in the history
…oft#22972)

### Description
Some ops should allow empty tensor as input, e.g. roi, scales inputs in
Resize
### Motivation and Context
It avoid some unexpected fallback for optional input with empty tensor.
e.g. roi and scales are both optional inputs in Resize, in some models
they have non-empty name but with empty initializer presented as `[0]`,
WebNN currently will fallback all nodes with 0 dimension, which is not
expected.

![image](https://github.com/user-attachments/assets/599ba351-b5f6-49ac-8a1f-69fb28dbaf9b)
  • Loading branch information
Honry authored and ankitm3k committed Dec 11, 2024
1 parent f94351f commit f125b3d
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 4 deletions.
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We
}
}

bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) {
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
const logging::Logger& logger, bool allow_empty_input) {
const auto& node_arg_name = node_arg.Name();
const auto* shape_proto = node_arg.Shape();
// Optional tensors can be indicated by an empty name, just ignore it.
Expand All @@ -89,7 +90,7 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name;
return false;
}
if (dim.dim_value() == 0) {
if (dim.dim_value() == 0 && !allow_empty_input) {
LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN";
return false;
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s
return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
}

bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
const logging::Logger& logger, bool allow_empty_input = false);

// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val&
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* input : node.InputDefs()) {
if (!IsTensorShapeSupported(*input, node_name, logger)) {
if (!IsTensorShapeSupported(*input, node_name, logger, allow_empty_tensor_as_input_)) {
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class BaseOpBuilder : public IOpBuilder {
const logging::Logger& logger) const override final ORT_MUST_USE_RESULT;

protected:
explicit BaseOpBuilder(bool allow_empty_tensor_as_input = false)
: allow_empty_tensor_as_input_(allow_empty_tensor_as_input) {
}
virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const ORT_MUST_USE_RESULT = 0;

Expand Down Expand Up @@ -55,6 +58,8 @@ class BaseOpBuilder : public IOpBuilder {
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;

const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input.
};

} // namespace webnn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ namespace webnn {
class ResizeOpBuilder : public BaseOpBuilder {
// Add operator related.
public:
// Allow roi and scales potentially being empty inputs that are ignored during processing.
ResizeOpBuilder() : BaseOpBuilder(/*allow empty inputs*/ true) {}
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;

private:
Expand Down

0 comments on commit f125b3d

Please sign in to comment.