Skip to content

Commit

Permalink
[WebNN EP] Add support for Op Pad. (#16732)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Support Op Pad for WebNN EP. It aims to support three modes (constant,
reflect and edge). For now, only constant can be tested with Chrome
Canary.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Support more models like SD1.5-VAE-encode.
  • Loading branch information
zesongw authored Jul 20, 2023
1 parent 2bc9fbb commit 0e40049
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 0 deletions.
38 changes: 38 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,43 @@ bool ReadIntArrayFrom1DTensor(const onnx::TensorProto& tensor, std::vector<T>& a
return true;
}

inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::val& scalar, const logging::Logger& logger) {
std::vector<uint8_t> unpacked_tensor;
auto status = onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor);
if (!status.IsOK()) {
LOGS(logger, ERROR) << "Error while unpacking tensor: " << status.ErrorMessage();
return false;
}
switch (tensor.data_type()) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
scalar = emscripten::val{*reinterpret_cast<uint8_t*>(unpacked_tensor.data())};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
scalar = emscripten::val{MLFloat16::FromBits(*reinterpret_cast<uint16_t*>(unpacked_tensor.data())).ToFloat()};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
scalar = emscripten::val{*reinterpret_cast<float*>(unpacked_tensor.data())};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
scalar = emscripten::val{*reinterpret_cast<int32_t*>(unpacked_tensor.data())};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
scalar = emscripten::val{*reinterpret_cast<int64_t*>(unpacked_tensor.data())};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
scalar = emscripten::val{*reinterpret_cast<uint32_t*>(unpacked_tensor.data())};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
scalar = emscripten::val{*reinterpret_cast<uint64_t*>(unpacked_tensor.data())};
break;
default:
LOGS(logger, ERROR) << "Unsupported data type : " << tensor.data_type();
return false;
break;
}
return true;
}

bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);

// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
Expand Down Expand Up @@ -128,6 +165,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Mul", "mul"},
{"Neg", "neg"},
{"Not", "logicalNot"},
{"Pad", "pad"},
{"Pow", "pow"},
{"Reciprocal", "reciprocal"},
{"ReduceMax", "reduceMax"},
Expand Down
171 changes: 171 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/safeint.h"
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"

#include "base_op_builder.h"
#include "builder_utils.h"

namespace onnxruntime {
namespace webnn {

class PadOpBuilder : public BaseOpBuilder {
// Add operator related.
public:
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;

private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
};

// Add operator related.

// ONNX mode to WebNN mode mapping.
const InlinedHashMap<std::string, std::string> supported_mode = {
{"constant", "constant"},
{"reflect", "reflection"},
{"edge", "edge"},
};

// Skip for pads, constant value, and axes.
void PadOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
for (size_t i = 1; i < node.InputDefs().size(); i++) {
model_builder.AddInitializerToSkip(node.InputDefs()[i]->Name());
model_builder.AddInputToSkip(node.InputDefs()[i]->Name());
}
}

Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& initializers = model_builder.GetInitializerTensors();
ORT_RETURN_IF(input_defs.size() < 1, "Pad has no inputs");
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");

emscripten::val options = emscripten::val::object();

NodeAttrHelper helper(node);
const auto pad_mode = helper.Get("mode", std::string("constant"));
std::vector<int32_t> start_padding;
std::vector<int32_t> end_padding;
ORT_RETURN_IF(supported_mode.find(pad_mode) == supported_mode.end(), "WebNN dose not support mode", pad_mode);
const auto webnn_mode = supported_mode.find(pad_mode)->second;
options.set("mode", emscripten::val(webnn_mode));

const auto opset = node.SinceVersion();
// From opset 11, pads, constant value and axes are inputs.
if (opset >= 11) {
ORT_RETURN_IF(input_defs.size() < 2, "Pads is required at opset ", opset);
std::vector<int64_t> pads;
const auto& pads_tensor = *initializers.at(input_defs[1]->Name());
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(pads_tensor, pads, logger), "Error while read pads tensor");

// Constant value and axes are optional.
if (input_defs.size() >= 3) {
const auto value_tensor = *initializers.at(input_defs[2]->Name());
emscripten::val value = emscripten::val::object();
ORT_RETURN_IF_NOT(ReadScalarTensorData(value_tensor, value, logger), "Cannot read constant value");
options.set("value", value);
}

if (input_defs.size() == 4) {
const auto input_rank = input_shape.size();
std::vector<int64_t> axes;
const auto& axes_tensor = *initializers.at(input_defs[3]->Name());
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(axes_tensor, axes, logger), "Error while read axes tensor");
std::vector<size_t> axes_index;
std::transform(
axes.begin(), axes.end(), std::back_inserter(axes_index),
[input_rank](int64_t axis) -> int32_t { return SafeInt<int32_t>(HandleNegativeAxis(axis, input_rank)); });
start_padding.resize(input_rank, 0);
end_padding.resize(input_rank, 0);
for (size_t i = 0; i < axes_index.size(); i++) {
size_t index = axes_index[i];
start_padding[index] = SafeInt<int32_t>(pads[i]);
end_padding[index] = SafeInt<int32_t>(pads[i + pads.size() / 2]);
}
} else {
std::transform(pads.begin(), pads.begin() + pads.size() / 2, std::back_inserter(start_padding),
[](int64_t axis) -> int32_t { return SafeInt<int32_t>(axis); });

std::transform(pads.begin() + pads.size() / 2, pads.end(), std::back_inserter(end_padding),
[](int64_t axis) -> int32_t { return SafeInt<int32_t>(axis); });
}
} else {
// Before opset 11, pads, constant value are attributes.
ORT_RETURN_IF_NOT(helper.HasAttr("pads"), "Pads is required as attribute in opset ", opset);
const auto pads = helper.Get("pads", std::vector<int>());
const auto value = helper.Get("value", 0.0f);
start_padding = std::vector<int32_t>(pads.begin(), pads.begin() + pads.size() / 2);
end_padding = std::vector<int32_t>(pads.begin() + pads.size() / 2, pads.end());
options.set("value", value);
}

emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("pad", input,
emscripten::val::array(start_padding),
emscripten::val::array(end_padding),
options);

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
}

// Operator support related.
bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
const Node& node,
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& op_type = node.OpType();
const auto& input_defs = node.InputDefs();
const auto opset = node.SinceVersion();

NodeAttrHelper helper(node);
const auto pad_mode = helper.Get("mode", "constant");
if (supported_mode.find(pad_mode) == supported_mode.end()) {
LOGS(logger, VERBOSE) << op_type << " WebNN does not support mode " << pad_mode;
return false;
}

if (input_defs.size() < 1) {
LOGS(logger, VERBOSE) << op_type << " requires at least one input (data)";
return false;
}

if (opset >= 11) {
if (input_defs.size() < 2) {
LOGS(logger, VERBOSE) << op_type << " at opset " << opset << " requires at least two inputs (data and pads)";
return false;
}
for (size_t i = 1; i < input_defs.size(); i++) {
if (!Contains(initializers, input_defs[i]->Name())) {
LOGS(logger, VERBOSE) << "Input [" << input_defs[i]->Name() << "] must be known as initializer";
return false;
}
}
}

return true;
} // namespace webnn

void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<PadOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}

} // namespace webnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateNormalizationOpBuilder("LayerNormalization", op_registrations);
}

{ // Pad
CreatePadOpBuilder("Pad", op_registrations);
}

{ // Pool
CreatePoolOpBuilder("GlobalAveragePool", op_registrations);
CreatePoolOpBuilder("GlobalMaxPool", op_registrations);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& o
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand Down

0 comments on commit 0e40049

Please sign in to comment.