[WebNN] Improve the util function of creating WebNN constant MLOperand #22935

Merged (3 commits, Dec 4, 2024)
Changes from 2 commits
onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc
@@ -311,12 +311,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
    if (input_defs.size() >= 3) {
      x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name());
    } else {
-     x_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
+     x_zero_point = model_builder.CreateOrGetConstant<uint8_t>(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0);
    }
    if (input_defs.size() >= 4) {
      w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name());
    } else {
-     w_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
+     w_zero_point = model_builder.CreateOrGetConstant<uint8_t>(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0);
    }
    output = model_builder.GetBuilder().call<emscripten::val>("conv2dInteger",
                                                              input, x_zero_point, filter, w_zero_point, options);
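
Note: CreateOrGetConstant also deduplicates. A minimal sketch of the caching behavior assumed here, using the key format visible in the model_builder.h hunk below and assuming ONNX_NAMESPACE::TensorProto_DataType_UINT8 == 2:

    // First call builds the constant and stores it in wnn_operands_
    // under the key "webnn_constant_2_0".
    emscripten::val zp_a = model_builder.CreateOrGetConstant<uint8_t>(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0);
    // Second call with the same (type, value) is a cache hit:
    // no new WebNN constant is created.
    emscripten::val zp_b = model_builder.CreateOrGetConstant<uint8_t>(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0);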

onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc
@@ -59,22 +59,14 @@
    std::vector<int64_t> mask_shape;
    ORT_RETURN_IF_NOT(GetShape(*output_defs[1], mask_shape, logger), "Cannot get mask output's shape");
    std::vector<uint32_t> dims = GetVecUint32FromVecInt64(mask_shape);
-
-   emscripten::val desc = emscripten::val::object();
-   desc.set("dataType", "uint8");
-   desc.set("dimensions", emscripten::val::array(dims));
-   desc.set("shape", emscripten::val::array(dims));
-   const auto num_elements = narrow<uint32_t>(Product(mask_shape));
-   emscripten::val ones_buffer = emscripten::val::global("Uint8Array").new_(num_elements);
-   ones_buffer.call<void>("fill", 1);
-
-   emscripten::val mask_output = model_builder.GetBuilder().call<emscripten::val>("constant", desc, ones_buffer);
+   emscripten::val one_constant = model_builder.CreateOrGetConstant<uint8_t>(
+       ONNX_NAMESPACE::TensorProto_DataType_BOOL, 1, dims);

    emscripten::val options = emscripten::val::object();
    options.set("label", output_defs[1]->Name() + "_identity");
+   // Add additional identity op in case the mask is the output of a WebNN graph,
+   // beacuse WebNN does not support a constant operand as output.

[GitHub Actions / Optional Lint] misspell (reviewdog): "beacuse" is a misspelling of "because" (onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc:68)

-   mask_output = model_builder.GetBuilder().call<emscripten::val>("identity", mask_output, options);
+   emscripten::val mask_output = model_builder.GetBuilder().call<emscripten::val>("identity", one_constant, options);
    model_builder.AddOperand(output_defs[1]->Name(), std::move(mask_output));
  }
  return Status::OK();
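
Note: the identity wrap matters because WebNN does not allow a constant MLOperand to be a graph output. A minimal sketch of the pattern (hedged; `output_name` is a hypothetical stand-in for output_defs[1]->Name()):

    // Invalid: registering the constant directly as a graph output.
    // model_builder.AddOperand(output_name, one_constant);
    // Valid: route the constant through identity so the output is a regular operand.
    emscripten::val opts = emscripten::val::object();
    opts.set("label", output_name + "_identity");
    emscripten::val out = model_builder.GetBuilder().call<emscripten::val>("identity", one_constant, opts);
    model_builder.AddOperand(output_name, std::move(out));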

onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc
@@ -113,12 +113,12 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
    if (input_defs.size() >= 3) {
      a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name());
    } else {
-     a_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
+     a_zero_point = model_builder.CreateOrGetConstant<uint8_t>(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0);
    }
    if (input_defs.size() >= 4) {
      b_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name());
    } else {
-     b_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
+     b_zero_point = model_builder.CreateOrGetConstant<uint8_t>(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0);
    }
    output = model_builder.GetBuilder().call<emscripten::val>("matmulInteger",
                                                              a,
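
Note: for context on why the zero constants are required, matmulInteger (like ONNX MatMulInteger) computes the product on zero-point-adjusted inputs, roughly:

    output = matmul(a - a_zero_point, b - b_zero_point)

so when the ONNX node omits a zero-point input, an explicit zero constant has to be materialized rather than leaving the argument undefined.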

11 changes: 6 additions & 5 deletions onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc
@@ -29,7 +29,8 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                             const Node& node,
                                             const logging::Logger& logger) const {
  const auto& input_defs = node.InputDefs();
- const auto input_data_type = input_defs[0]->TypeAsProto()->tensor_type().elem_type();
+ int32_t input_data_type;
+ ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_data_type, logger), "Cannot get input type");
  emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
  const auto node_name = node.Name();
  emscripten::val wnn_builder = model_builder.GetBuilder();
@@ -42,10 +43,10 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

  // Prepare WebNN constants for alpha, beta, bias attributes.
  // Assume T is float, because input_data_type has been limited to float32 and float16 in 'hasSupportedInitsImpl'.
- emscripten::val alpha_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, alpha);
- emscripten::val beta_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, beta);
- emscripten::val bias_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, bias);
- emscripten::val pow1_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, 2);
+ emscripten::val alpha_constant = model_builder.CreateOrGetConstant<float>(input_data_type, alpha);
+ emscripten::val beta_constant = model_builder.CreateOrGetConstant<float>(input_data_type, beta);
+ emscripten::val bias_constant = model_builder.CreateOrGetConstant<float>(input_data_type, bias);
+ emscripten::val pow1_constant = model_builder.CreateOrGetConstant<float>(input_data_type, 2);

  /**
  WebNN doesn't support LRN. So decompose it into a series of ops:
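
Note: the rest of this comment block is collapsed in the diff view. For reference, ONNX LRN (which the constants above serve) computes, with the sum of squares taken over a window of `size` adjacent channels:

    square_sum[n, c, h, w] = sum over the channel window of x[n, i, h, w]^2
    y = x / (bias + (alpha / size) * square_sum)^beta

which is why alpha, beta, bias, and the exponent 2 all need to exist as WebNN constants.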

onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc
@@ -100,21 +100,15 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
  X --> Pow --> ReduceMean --> Add --> Sqrt --> Div -> Mul
         ^          ^           ^                ^      ^
         |          |           |                |      |
-       Y:2        axis     B:epsilon           A:X  A:scale
+        Y:2        axis     B:epsilon          A:X   A:scale
  */

  int32_t input_type;
  ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input type");
  emscripten::val common_options = emscripten::val::object();

  // Pow
- emscripten::val pow_constant_desc = emscripten::val::object();
- ORT_RETURN_IF_NOT(SetWebnnDataType(pow_constant_desc, input_type), "Unsupported data type");
- pow_constant_desc.set("shape", emscripten::val::array());
- emscripten::val pow_buffer = emscripten::val::global("Float32Array").new_(1);
- pow_buffer.set(0, 2);
- emscripten::val pow_constant =
-     model_builder.GetBuilder().call<emscripten::val>("constant", pow_constant_desc, pow_buffer);
+ emscripten::val pow_constant = model_builder.CreateOrGetConstant<float>(input_type, 2);
  common_options.set("label", node.Name() + "_pow");
  emscripten::val pow =
      model_builder.GetBuilder().call<emscripten::val>("pow", input, pow_constant, common_options);
@@ -127,13 +121,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
  emscripten::val reduce_mean = model_builder.GetBuilder().call<emscripten::val>("reduceMean", pow, reduce_options);

  // Add
- emscripten::val add_constant_desc = emscripten::val::object();
- ORT_RETURN_IF_NOT(SetWebnnDataType(add_constant_desc, input_type), "Unsupported data type");
- add_constant_desc.set("shape", emscripten::val::array());
- emscripten::val add_buffer = emscripten::val::global("Float32Array").new_(1);
- add_buffer.set(0, epsilon);
- emscripten::val add_constant =
-     model_builder.GetBuilder().call<emscripten::val>("constant", add_constant_desc, add_buffer);
+ emscripten::val add_constant = model_builder.CreateOrGetConstant<float>(input_type, epsilon);
  common_options.set("label", node.Name() + "_add");
  emscripten::val add =
      model_builder.GetBuilder().call<emscripten::val>("add", reduce_mean, add_constant, common_options);
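
Note: combining the diagram with these two hunks, the decomposition evaluates a simplified LayerNorm (RMSNorm-style):

    y = scale * x / sqrt(reduceMean(x^2) + epsilon)

with the exponent 2 and epsilon now obtained from CreateOrGetConstant instead of hand-built constant descriptors and buffers.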

onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc
@@ -100,7 +100,10 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
      // zero_point has the same shape as the scale tensor.
      zero_point_shape = GetVecUint32FromVecInt64(scale_shape);
    }
-   zero_point = model_builder.GetZeroConstant(zero_point_type, zero_point_shape);
+   // Create a zero constant with the same shape as the scale tensor.
+   // The zero value has been pre-processed in the CreateOrGetConstant function,
+   // so the type of T is not relevant here.
+   zero_point = model_builder.CreateOrGetConstant<uint8_t>(zero_point_type, 0, zero_point_shape);
  }

  emscripten::val options = emscripten::val::object();
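
Note: the "type of T is not relevant" comment holds because of how the helper treats zero. A sketch of the assumed behavior (JS typed arrays are zero-initialized on construction):

    // Even with T = uint8_t and zero_point_type == TensorProto_DataType_INT64,
    // CreateOrGetConstant allocates a BigInt64Array(num_elements), which is
    // already all zeros, and skips the fill branch because `if (value)` is
    // false when value == 0.
    zero_point = model_builder.CreateOrGetConstant<uint8_t>(zero_point_type, 0, zero_point_shape);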

68 changes: 0 additions & 68 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -14,7 +14,6 @@
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"

#include <sstream>
#include <utility>

namespace onnxruntime {
Expand Down Expand Up @@ -385,73 +384,6 @@ void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& op
  wnn_operands_.insert(std::make_pair(name, operand));
}

-// Get the zero constant with shape.
-const emscripten::val& ModelBuilder::GetZeroConstant(const int32_t& data_type,
-                                                     const std::vector<uint32_t>& shape) {
-  std::string name = "webnn_zero_constant_" + std::to_string(data_type);
-  emscripten::val dims = emscripten::val::array();
-  if (!shape.empty()) {
-    dims = emscripten::val::array(shape);
-    std::ostringstream name_stream;
-    name_stream << name;
-    for (const auto& dim : shape) {
-      name_stream << "_" << dim;
-    }
-    name = name_stream.str();
-  }
-  // If the operand does not exist, create it.
-  if (wnn_operands_.find(name) == wnn_operands_.end()) {
-    emscripten::val desc = emscripten::val::object();
-    desc.set("dimensions", dims);
-    desc.set("shape", dims);
-    emscripten::val zero_buffer = emscripten::val::undefined();
-    if (!SetWebnnDataType(desc, data_type)) {
-      ORT_THROW("Unsupported data type: " + std::to_string(data_type));
-    }
-    auto num_elements = Product(shape);
-    switch (data_type) {
-      case ONNX_NAMESPACE::TensorProto_DataType_INT4:
-      case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
-        // For WebNN int4 and uint4 tensors are stored in Uint8Array,
-        // so we need to adjust the number of elements.
-        num_elements = (num_elements + 1) / 2;
-        zero_buffer = emscripten::val::global("Uint8Array").new_(num_elements);
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
-      case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
-        zero_buffer = emscripten::val::global("Uint8Array").new_(num_elements);
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_INT8:
-        zero_buffer = emscripten::val::global("Int8Array").new_(num_elements);
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
-        zero_buffer = emscripten::val::global("Uint16Array").new_(num_elements);
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
-        zero_buffer = emscripten::val::global("Float32Array").new_(num_elements);
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_INT32:
-        zero_buffer = emscripten::val::global("Int32Array").new_(num_elements);
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_INT64:
-        zero_buffer = emscripten::val::global("BigInt64Array").new_(num_elements);
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
-        zero_buffer = emscripten::val::global("Uint32Array").new_(num_elements);
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
-        zero_buffer = emscripten::val::global("BigUint64Array").new_(num_elements);
-        break;
-      default:
-        break;
-    }
-
-    emscripten::val zero_constant = wnn_builder_.call<emscripten::val>("constant", desc, zero_buffer);
-    wnn_operands_.insert(std::make_pair(name, zero_constant));
-  }
-  return wnn_operands_.at(name);
-}

void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) {
  skipped_initializers_.insert(tensor_name);
}

110 changes: 69 additions & 41 deletions onnxruntime/core/providers/webnn/builders/model_builder.h
@@ -11,6 +11,7 @@
#include "core/framework/execution_provider.h"
#include "core/providers/webnn/builders/helper.h"

#include <sstream>

Check warning on line 14 in onnxruntime/core/providers/webnn/builders/model_builder.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: model_builder.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/webnn/builders/model_builder.h:14: Found C++ system header after other header. Should be: model_builder.h, c system, c++ system, other. [build/include_order] [4]
#include <emscripten.h>
#include <emscripten/val.h>

@@ -38,11 +39,10 @@
  const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; }

  void AddOperand(const std::string& name, const emscripten::val& operand);
- const emscripten::val& GetZeroConstant(
-     const int32_t& data_type, const std::vector<uint32_t>& shape = {});

  template <typename T>
- const emscripten::val& CreateOrGetScalarConstant(const int32_t& data_type, T value);
+ const emscripten::val& CreateOrGetConstant(const int32_t& data_type, T value,
+                                            const std::vector<uint32_t>& shape = {});

  // Use the buffers to persist WebNN allocated data like transposed weight.
  // It ensures the validity during inference session.
@@ -103,11 +103,12 @@
  static const IOpBuilder* GetOpBuilder(const Node& node);
};

-// Create a scalar constant MLOperand of the specified value and data type.
-// Workaround for builer.constant(type, value) method since it has not been implemented now.
+// Create or retrieve one of the following:
+// - A WebNN constant MLOperand filled with the specified value, data type, and shape.
+// - A WebNN scalar constant MLOperand with the specified value and data type.
+// For scalar constant, it is workaround for builer.constant(type, value) method since
+// it has not been implemented now.
// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-type-value
// BTW, the spec is discussing if the builder.constant(type, value) should be dropped at
// https://github.com/webmachinelearning/webnn/issues/475. Fix me according to the spec decision.
//
// This function enforces a mapping between the data_type and the value types:
// - TensorProto_DataType_INT4 <-> int8_t
@@ -122,69 +123,96 @@
// - TensorProto_DataType_UINT32 <-> uint32_t
// - TensorProto_DataType_UINT64 <-> uint64_t
template <typename T>
-const emscripten::val& ModelBuilder::CreateOrGetScalarConstant(const int32_t& data_type, T value) {
-  std::string name = "webnn_scalar_constant_" + std::to_string(data_type) + "_" + std::to_string(value);
-  emscripten::val desc = emscripten::val::object();
-  desc.set("shape", emscripten::val::array());
-  emscripten::val scalar_buffer = emscripten::val::undefined();
-  uint16_t value_uint16 = 0;
-  uint8_t value_uint8 = 0;
-  if (!SetWebnnDataType(desc, data_type)) {
-    ORT_THROW("Unsupported data type: " + std::to_string(data_type));
+const emscripten::val& ModelBuilder::CreateOrGetConstant(const int32_t& data_type, T value,
+                                                         const std::vector<uint32_t>& shape) {

[GitHub Actions / Optional Lint C++] cpplint (reviewdog): Add #include <vector> for vector<> [build/include_what_you_use] [4] (model_builder.h:127)

+  std::string name = "webnn_constant_" + std::to_string(data_type) + "_" + std::to_string(value);

[GitHub Actions / Optional Lint C++] cpplint (reviewdog): Add #include <string> for string [build/include_what_you_use] [4] (model_builder.h:128)

+  emscripten::val dims = emscripten::val::array();
+  if (!shape.empty()) {
+    dims = emscripten::val::array(shape);
+    std::ostringstream name_stream;
+    name_stream << name;
+    for (const auto& dim : shape) {
+      name_stream << "_" << dim;
+    }
+    name = name_stream.str();
  }
+
+  // If the operand does not exist, create it.
+  if (wnn_operands_.find(name) == wnn_operands_.end()) {
+    emscripten::val desc = emscripten::val::object();
+    desc.set("shape", dims);
+    desc.set("dimensions", dims);
+    emscripten::val buffer = emscripten::val::undefined();
+    if (!SetWebnnDataType(desc, data_type)) {
+      ORT_THROW("Unsupported data type: " + std::to_string(data_type));
+    }
+    auto num_elements = Product(shape);
    switch (data_type) {
      case ONNX_NAMESPACE::TensorProto_DataType_INT4:
      case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
-        scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
-        value_uint8 = PackInt8ToUint8AsNibble(value, data_type);
-        scalar_buffer.call<void>("fill", emscripten::val(value_uint8));
+        // For WebNN int4 and uint4 tensors are stored in Uint8Array,
+        // so we need to adjust the number of elements.
+        num_elements = (num_elements + 1) / 2;
+        buffer = emscripten::val::global("Uint8Array").new_(num_elements);
+        if (value) {
+          buffer.call<void>("fill", emscripten::val(PackInt8ToUint8AsNibble(value, data_type)));
+        }
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
-        scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
-        scalar_buffer.call<void>("fill", emscripten::val(value ? 1 : 0));
-        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
-        scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
-        scalar_buffer.call<void>("fill", emscripten::val(value));
+        buffer = emscripten::val::global("Uint8Array").new_(num_elements);
+        if (value) {
+          buffer.call<void>("fill", emscripten::val(value));
+        }
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT8:
-        scalar_buffer = emscripten::val::global("Int8Array").new_(1);
-        scalar_buffer.call<void>("fill", emscripten::val(value));
+        buffer = emscripten::val::global("Int8Array").new_(num_elements);
+        if (value) {
+          buffer.call<void>("fill", emscripten::val(value));
+        }
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
-        scalar_buffer = emscripten::val::global("Uint16Array").new_(1);
-        value_uint16 = PackFloat32ToUint16AsFloat16(value);
-        scalar_buffer.call<void>("fill", emscripten::val(value_uint16));
+        buffer = emscripten::val::global("Uint16Array").new_(num_elements);
+        if (value) {
+          buffer.call<void>("fill", emscripten::val(PackFloat32ToUint16AsFloat16(value)));
+        }
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
-        scalar_buffer = emscripten::val::global("Float32Array").new_(1);
-        scalar_buffer.call<void>("fill", emscripten::val(value));
+        buffer = emscripten::val::global("Float32Array").new_(num_elements);
+        if (value) {
+          buffer.call<void>("fill", emscripten::val(value));
+        }
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT32:
-        scalar_buffer = emscripten::val::global("Int32Array").new_(1);
-        scalar_buffer.call<void>("fill", emscripten::val(value));
+        buffer = emscripten::val::global("Int32Array").new_(num_elements);
+        if (value) {
+          buffer.call<void>("fill", emscripten::val(value));
+        }
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
-        scalar_buffer = emscripten::val::global("Uint32Array").new_(1);
-        scalar_buffer.call<void>("fill", emscripten::val(value));
+        buffer = emscripten::val::global("Uint32Array").new_(num_elements);
+        if (value) {
+          buffer.call<void>("fill", emscripten::val(value));
+        }
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT64:
-        scalar_buffer = emscripten::val::global("BigInt64Array").new_(1);
-        scalar_buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
+        buffer = emscripten::val::global("BigInt64Array").new_(num_elements);
+        if (value) {
+          buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
+        }
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
-        scalar_buffer = emscripten::val::global("BigUint64Array").new_(1);
-        scalar_buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
+        buffer = emscripten::val::global("BigUint64Array").new_(num_elements);
+        if (value) {
+          buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
+        }
        break;
      default:
        break;
    }

-  const emscripten::val scalar_constant = wnn_builder_.call<emscripten::val>("constant", desc, scalar_buffer);
-  wnn_operands_.insert(std::make_pair(name, scalar_constant));
+    const emscripten::val constant = wnn_builder_.call<emscripten::val>("constant", desc, buffer);
+    wnn_operands_.insert(std::make_pair(name, constant));

[GitHub Actions / Optional Lint C++] cpplint (reviewdog): Add #include <utility> for make_pair [build/include_what_you_use] [4] (model_builder.h:215)

+  }
+
  return wnn_operands_.at(name);
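
Note: a short usage sketch of the new helper as exercised by this PR (values taken from the LRN and Dropout hunks above; `dims` stands for the mask shape computed there):

    // Scalar float constant (shape defaults to {}), e.g. the exponent for LRN's pow:
    emscripten::val two = model_builder.CreateOrGetConstant<float>(
        ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 2);
    // Shaped constant of ones, e.g. Dropout's mask output:
    emscripten::val ones = model_builder.CreateOrGetConstant<uint8_t>(
        ONNX_NAMESPACE::TensorProto_DataType_BOOL, 1, dims);
    // A repeated call with the same (data_type, value, shape) returns the cached operand.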