Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebNN] Allow ops to handle ignoring an empty tensor as input #22972

Merged
merged 2 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We
}
}

bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) {
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
const logging::Logger& logger, bool allow_empty_input) {
const auto& node_arg_name = node_arg.Name();
const auto* shape_proto = node_arg.Shape();
// Optional tensors can be indicated by an empty name, just ignore it.
Expand All @@ -89,7 +90,7 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name;
return false;
}
if (dim.dim_value() == 0) {
if (dim.dim_value() == 0 && !allow_empty_input) {
LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN";
return false;
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s
return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
}

bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
const logging::Logger& logger, bool allow_empty_input = false);

// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val&
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* input : node.InputDefs()) {
if (!IsTensorShapeSupported(*input, node_name, logger)) {
if (!IsTensorShapeSupported(*input, node_name, logger, allow_empty_tensor_as_input_)) {
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class BaseOpBuilder : public IOpBuilder {
const logging::Logger& logger) const override final ORT_MUST_USE_RESULT;

protected:
explicit BaseOpBuilder(bool allow_empty_tensor_as_input = false)
: allow_empty_tensor_as_input_(allow_empty_tensor_as_input) {
}
virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const ORT_MUST_USE_RESULT = 0;

Expand Down Expand Up @@ -55,6 +58,8 @@ class BaseOpBuilder : public IOpBuilder {
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;

const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input.
};

} // namespace webnn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ namespace webnn {
class ResizeOpBuilder : public BaseOpBuilder {
// Add operator related.
public:
// Allow roi and scales potentially being empty inputs that are ignored during processing.
ResizeOpBuilder() : BaseOpBuilder(/*allow empty inputs*/ true) {}
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;

private:
Expand Down
Loading