diff --git a/DEPS b/DEPS
index 1aa5a3a13..4e1ccc2c1 100644
--- a/DEPS
+++ b/DEPS
@@ -62,7 +62,7 @@ deps = {
     'url': '{github_git}/oneapi-src/oneDNN.git@4a129541fd4e67e6897072186ea2817a3154eddd',
   },
   'third_party/XNNPACK': {
-    'url': '{github_git}/google/XNNPACK.git@60fc61373f21f0ad3164cc719de464f4b787dc04'
+    'url': '{github_git}/google/XNNPACK.git@a9c0465458e185d8189f483b1bdd8965a3341838'
   },
   'third_party/onnxruntime': {
     'url': '{github_git}/microsoft/onnxruntime.git@0d9030e79888d1d5828730b254fedc53c7b640c1',
diff --git a/README.md b/README.md
index a0ca63a9f..40803d187 100644
--- a/README.md
+++ b/README.md
@@ -104,7 +104,7 @@ Currently "cpu", "gpu" and "default" are supported, more devices are to be suppo
 **Notes**:
   * For OpenVINO backend, please [install 2021.4 version](https://docs.openvinotoolkit.org/2021.4/openvino_docs_install_guides_installing_openvino_linux.html#install-openvino) and [set the environment variables](https://docs.openvinotoolkit.org/2021.4/openvino_docs_install_guides_installing_openvino_linux.html#set-the-environment-variables) before running the end2end tests.
-  * The current implementation of XNNPACK, oneDNN and MLAS backends is mainly for the investigation of WebNN [Operation Level Execution
+  * The current implementation of oneDNN and MLAS backends is mainly for the investigation of WebNN [Operation Level Execution
 ](https://webmachinelearning.github.io/webnn/#usecase-op-level-exec) use case. So only a limited set of tests (such as of conv2d) is expected to pass.
 
 ### Run examples
diff --git a/src/tests/end2end/GemmTests.cpp b/src/tests/end2end/GemmTests.cpp
index 6ff6ba90d..d3339cc76 100644
--- a/src/tests/end2end/GemmTests.cpp
+++ b/src/tests/end2end/GemmTests.cpp
@@ -31,10 +31,16 @@ class GemmTests : public WebnnTest {
                   const std::vector<float>& bData,
                   const std::vector<int32_t>& expectedShape,
                   const std::vector<float>& expectedValue,
-                  const Options* options = nullptr) {
+                  const Options* options = nullptr,
+                  bool constantWeight = false) {
        const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext());
        const wnn::Operand a = utils::BuildInput(builder, "a", aShape);
-       const wnn::Operand b = utils::BuildInput(builder, "b", bShape);
+       wnn::Operand b;
+       if (constantWeight) {
+           b = utils::BuildConstant(builder, bShape, bData.data(), bData.size() * sizeof(float));
+       } else {
+           b = utils::BuildInput(builder, "b", bShape);
+       }
        wnn::GemmOptions gemmOptions = {};
        if (options != nullptr) {
            if (!options->cData.empty()) {
@@ -52,7 +58,11 @@ class GemmTests : public WebnnTest {
        const wnn::Graph graph = utils::Build(builder, {{"c", gemm}});
        ASSERT_TRUE(graph);
        std::vector<float> result(utils::SizeOfShape(expectedShape));
-       utils::Compute(graph, {{"a", aData}, {"b", bData}}, {{"c", result}});
+       if (constantWeight) {
+           utils::Compute(graph, {{"a", aData}}, {{"c", result}});
+       } else {
+           utils::Compute(graph, {{"a", aData}, {"b", bData}}, {{"c", result}});
+       }
        EXPECT_TRUE(utils::CheckValue(result, expectedValue));
    }
 };
@@ -192,6 +202,29 @@ TEST_F(GemmTests, NoBias) {
     TestGemm(inputAShape, inputAData, inputBShape, inputBData, expectedShape, expectedValue);
 }
 
+TEST_F(GemmTests, NoBiasWithConstantWeight) {
+    const std::vector<int32_t> inputAShape = {2, 10};
+    const std::vector<float> inputAData = {
+        0.97596496, 0.47531518, 0.7147315, 0.14236908, 0.06151228, 0.05889508, 0.3534669,
+        0.31915423, 0.61336106, 0.5946216, 0.21969128, 0.7347848, 0.4087221, 0.00412959,
+        0.77303815, 0.6495765, 0.3174799, 0.62841094, 0.7002717, 0.63384914,
+    };
+    const std::vector<int32_t> inputBShape = {10, 3};
+    const std::vector<float> inputBData = {
+        0.51739925, 0.25108355, 0.31373033, 0.6488124, 0.9777175, 0.13308926,
+        0.47903556, 0.23692878, 0.0822504, 0.3080891, 0.51966125, 0.969734,
+        0.6691261, 0.59346807, 0.7651862, 0.48655444, 0.48373327, 0.2799068,
+        0.35760838, 0.19906454, 0.3612888, 0.11448191, 0.19188708, 0.00769753,
+        0.3161914, 0.323555, 0.17573832, 0.79587144, 0.91238266, 0.5517277,
+    };
+    const std::vector<int32_t> expectedShape = {2, 3};
+    const std::vector<float> expectedValue = {
+        2.0995352, 1.8906747, 1.1958704, 2.5321422, 2.6342242, 1.5699927,
+    };
+    TestGemm(inputAShape, inputAData, inputBShape, inputBData, expectedShape, expectedValue,
+             nullptr, true);
+}
+
 TEST_F(GemmTests, ScalarBias) {
     const std::vector<int32_t> inputAShape = {2, 3};
     const std::vector<float> inputAData = {
@@ -213,6 +246,27 @@
                             &options);
 }
 
+TEST_F(GemmTests, BiasWithConstantWeight) {
+    const std::vector<int32_t> inputAShape = {2, 3};
+    const std::vector<float> inputAData = {
+        0.41595492, 0.7063231, 0.3784654, 0.3524597, 0.41936764, 0.08190536,
+    };
+    const std::vector<int32_t> inputBShape = {3, 4};
+    const std::vector<float> inputBData = {
+        0.38356313, 0.92939967, 0.06164686, 0.09034675, 0.34704673, 0.9492532,
+        0.7738587, 0.93576515, 0.49937814, 0.38543963, 0.02364575, 0.80216527,
+    };
+    const std::vector<int32_t> expectedShape = {2, 4};
+    const std::vector<float> expectedValue = {
+        3.7336695, 4.3429437, 3.7211857, 4.1421247, 3.4616325, 3.8972316, 3.4881961, 3.6299748,
+    };
+    Options options;
+    options.cShape = {4};
+    options.cData = {3.14, 3.14, 3.14, 3.14};
+    TestGemm(inputAShape, inputAData, inputBShape, inputBData, expectedShape, expectedValue,
+             &options, true);
+}
+
 TEST_F(GemmTests, BroadcastingBias) {
     const std::vector<int32_t> inputAShape = {3, 7};
     const std::vector<float> inputAData = {
@@ -287,3 +341,28 @@ TEST_F(GemmTests, bTranspose) {
     TestGemm(inputAShape, inputAData, inputBShape, inputBData, expectedShape, expectedValue,
              &options);
 }
+
+TEST_F(GemmTests, bTransposeWithConstantWeight) {
+    const std::vector<int32_t> inputAShape = {3, 6};
+    const std::vector<float> inputAData = {
+        0.4520783, 0.25709572, 0.28996432, 0.03766193, 0.0546827, 0.46305302,
+        0.91171485, 0.48380807, 0.09058774, 0.6646215, 0.35773644, 0.03604647,
+        0.21229707, 0.18758385, 0.01589681, 0.9606218, 0.08803706, 0.18099776,
+    };
+    const std::vector<int32_t> inputBShape = {4, 6};
+    const std::vector<float> inputBData = {
+        0.1482661, 0.27676222, 0.10893039, 0.8347901, 0.7146212, 0.7316929,
+        0.97991717, 0.97123116, 0.69798464, 0.8436566, 0.9630883, 0.23252074,
+        0.09898344, 0.08882044, 0.90780985, 0.7116153, 0.5819304, 0.6742051,
+        0.5233705, 0.5594687, 0.963364, 0.1351259, 0.8119938, 0.13756031,
+    };
+    const std::vector<int32_t> expectedShape = {3, 4};
+    const std::vector<float> expectedValue = {
+        0.57909805, 1.0871967, 0.7016311, 0.77297145, 1.1157843, 2.340149,
+        0.92088836, 1.2203549, 1.0823897, 1.3386247, 0.9089607, 0.45756027,
+    };
+    Options options;
+    options.bTranspose = true;
+    TestGemm(inputAShape, inputAData, inputBShape, inputBData, expectedShape, expectedValue,
+             &options, true);
+}
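The new trailing `constantWeight` parameter defaults to `false`, so existing call sites keep their behavior. A hedged usage sketch of the two modes (the variables here are placeholders, not data from this file):

```cpp
// Weights bound at compute time as the named graph input "b" (existing behavior).
TestGemm(aShape, aData, bShape, bData, outShape, expected);
// Weights baked into the graph via utils::BuildConstant; Compute() then only
// binds "a". This exercises backends that specialize constant weights, such as
// the XNNPACK fully-connected path added in this change.
TestGemm(aShape, aData, bShape, bData, outShape, expected, /*options=*/nullptr,
         /*constantWeight=*/true);
```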
diff --git a/src/tests/end2end/SplitTests.cpp b/src/tests/end2end/SplitTests.cpp
index 0d035d549..8cc2c95db 100644
--- a/src/tests/end2end/SplitTests.cpp
+++ b/src/tests/end2end/SplitTests.cpp
@@ -54,24 +54,29 @@ class SplitTests : public WebnnTest {
     }
 };
 
-TEST_F(SplitTests, SplitByDefault) {
+TEST_F(SplitTests, SplitEvenByDefault) {
     testSplit({6}, {1, 2, 3, 4, 5, 6}, {3},
               {
                   {{2}, {1, 2}},
                   {{2}, {3, 4}},
                   {{2}, {5, 6}},
               });
+}
 
+TEST_F(SplitTests, SplitByDefault) {
     testSplit({6}, {1, 2, 3, 4, 5, 6}, {2, 4}, {{{2}, {1, 2}}, {{4}, {3, 4, 5, 6}}});
 }
 
-TEST_F(SplitTests, SplitOneDimension) {
+TEST_F(SplitTests, SplitEvenOneDimension) {
     testSplit({2, 6}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2},
               {
                   {{2, 3}, {1, 2, 3, 7, 8, 9}},
                   {{2, 3}, {4, 5, 6, 10, 11, 12}},
               },
               1);
+}
+
+TEST_F(SplitTests, SplitOneDimension) {
     testSplit({2, 6}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 4},
               {
                   {{2, 2}, {1, 2, 7, 8}},
diff --git a/src/webnn_native/BUILD.gn b/src/webnn_native/BUILD.gn
index 33cf6c3ca..123197b45 100644
--- a/src/webnn_native/BUILD.gn
+++ b/src/webnn_native/BUILD.gn
@@ -303,6 +303,8 @@ source_set("webnn_native_sources") {
 
   if (webnn_enable_xnnpack) {
     sources += [
+      "xnnpack/BackendXNN.cpp",
+      "xnnpack/BackendXNN.h",
      "xnnpack/ContextXNN.cpp",
      "xnnpack/ContextXNN.h",
      "xnnpack/GraphXNN.cpp",
diff --git a/src/webnn_native/Instance.cpp b/src/webnn_native/Instance.cpp
index 5497b3aab..b3ac4a766 100644
--- a/src/webnn_native/Instance.cpp
+++ b/src/webnn_native/Instance.cpp
@@ -53,6 +53,11 @@ namespace webnn_native {
         BackendConnection* Connect(InstanceBase* instance);
     }
 #endif  // defined(WEBNN_ENABLE_BACKEND_MLAS)
+#if defined(WEBNN_ENABLE_BACKEND_XNNPACK)
+    namespace xnnpack {
+        BackendConnection* Connect(InstanceBase* instance);
+    }
+#endif  // defined(WEBNN_ENABLE_BACKEND_XNNPACK)
 
     namespace {
 
@@ -73,6 +78,9 @@ namespace webnn_native {
 #if defined(WEBNN_ENABLE_BACKEND_MLAS)
             enabledBackends.set(wnn::BackendType::MLAS);
 #endif  // defined(WEBNN_ENABLE_BACKEND_MLAS)
+#if defined(WEBNN_ENABLE_BACKEND_XNNPACK)
+            enabledBackends.set(wnn::BackendType::XNNPACK);
+#endif  // defined(WEBNN_ENABLE_BACKEND_XNNPACK)
             return enabledBackends;
         }
 
@@ -137,6 +145,12 @@ namespace webnn_native {
                 break;
 #endif  // defined(WEBNN_ENABLE_BACKEND_MLAS)
 
+#if defined(WEBNN_ENABLE_BACKEND_XNNPACK)
+            case wnn::BackendType::XNNPACK:
+                Register(xnnpack::Connect(this), wnn::BackendType::XNNPACK);
+                break;
+#endif  // defined(WEBNN_ENABLE_BACKEND_XNNPACK)
+
             default:
                 UNREACHABLE();
         }
@@ -156,6 +170,8 @@ namespace webnn_native {
             return mBackends[wnn::BackendType::OneDNN]->CreateContext(options);
         } else if (mBackends.find(wnn::BackendType::MLAS) != mBackends.end()) {
             return mBackends[wnn::BackendType::MLAS]->CreateContext(options);
+        } else if (mBackends.find(wnn::BackendType::XNNPACK) != mBackends.end()) {
+            return mBackends[wnn::BackendType::XNNPACK]->CreateContext(options);
         }
         UNREACHABLE();
         return nullptr;
+ +#include "webnn_native/xnnpack/BackendXNN.h" + +#include "common/Log.h" +#include "webnn_native/Instance.h" +#include "webnn_native/xnnpack/ContextXNN.h" + +namespace webnn_native::xnnpack { + + Backend::Backend(InstanceBase* instance) + : BackendConnection(instance, wnn::BackendType::XNNPACK) { + } + + MaybeError Backend::Initialize() { + return {}; + } + + ContextBase* Backend::CreateContext(ContextOptions const* options) { + if (options->devicePreference == wnn::DevicePreference::Gpu) { + dawn::ErrorLog() << "XNNPACK backend only supports CPU device."; + return nullptr; + } + Ref context = AcquireRef(new Context(options)); + xnn_status status = reinterpret_cast(context.Get())->Init(); + if (status != xnn_status_success) { + dawn::ErrorLog() << "Failed to init XNNPACK:" << status; + return nullptr; + } + return context.Detach(); + } + + BackendConnection* Connect(InstanceBase* instance) { + Backend* backend = new Backend(instance); + + if (instance->ConsumedError(backend->Initialize())) { + delete backend; + return nullptr; + } + + return backend; + } + +} // namespace webnn_native::xnnpack diff --git a/src/webnn_native/xnnpack/BackendXNN.h b/src/webnn_native/xnnpack/BackendXNN.h new file mode 100644 index 000000000..1a06c9cb2 --- /dev/null +++ b/src/webnn_native/xnnpack/BackendXNN.h @@ -0,0 +1,36 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/src/webnn_native/xnnpack/BackendXNN.h b/src/webnn_native/xnnpack/BackendXNN.h
new file mode 100644
index 000000000..1a06c9cb2
--- /dev/null
+++ b/src/webnn_native/xnnpack/BackendXNN.h
@@ -0,0 +1,36 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef WEBNN_NATIVE_XNNPACK_BACKENDXNN_H_
+#define WEBNN_NATIVE_XNNPACK_BACKENDXNN_H_
+
+#include "webnn_native/BackendConnection.h"
+#include "webnn_native/Context.h"
+#include "webnn_native/Error.h"
+
+#include <xnnpack.h>
+
+namespace webnn_native::xnnpack {
+
+    class Backend : public BackendConnection {
+      public:
+        Backend(InstanceBase* instance);
+
+        MaybeError Initialize();
+        ContextBase* CreateContext(ContextOptions const* options = nullptr) override;
+    };
+
+}  // namespace webnn_native::xnnpack
+
+#endif  // WEBNN_NATIVE_XNNPACK_BACKENDXNN_H_
diff --git a/src/webnn_native/xnnpack/ContextXNN.cpp b/src/webnn_native/xnnpack/ContextXNN.cpp
index 83f355daf..dcb508a0b 100644
--- a/src/webnn_native/xnnpack/ContextXNN.cpp
+++ b/src/webnn_native/xnnpack/ContextXNN.cpp
@@ -22,17 +22,7 @@
 
 namespace webnn_native::xnnpack {
 
-    ContextBase* Create() {
-        Ref<ContextBase> context = AcquireRef(new Context());
-        xnn_status status = reinterpret_cast<Context*>(context.Get())->Init();
-        if (status != xnn_status_success) {
-            dawn::ErrorLog() << "Failed to init XNNPack:" << status;
-            return nullptr;
-        }
-        return context.Detach();
-    }
-
-    Context::Context() {
+    Context::Context(ContextOptions const* options) {
     }
 
     Context::~Context() {
diff --git a/src/webnn_native/xnnpack/ContextXNN.h b/src/webnn_native/xnnpack/ContextXNN.h
index 33dde3813..82b27b96a 100644
--- a/src/webnn_native/xnnpack/ContextXNN.h
+++ b/src/webnn_native/xnnpack/ContextXNN.h
@@ -23,7 +23,7 @@ namespace webnn_native::xnnpack {
 
     class Context : public ContextBase {
       public:
-        Context();
+        explicit Context(ContextOptions const* options);
        ~Context() override;
 
        xnn_status Init();
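The GraphXNN.cpp rewrite below replaces XNNPACK's per-operator API (xnn_create_* / xnn_setup_* / xnn_run_operator) with the subgraph API, so a whole WebNN graph compiles into a single xnn_runtime_t. For orientation, here is a minimal self-contained sketch of that API flow, using the same calls the new backend uses (error checking omitted; shapes and values are illustrative only):

```cpp
#include <xnnpack.h>
#include <array>
#include <cmath>

// out = a + b over f32 tensors of shape {4}, built and run via the subgraph API.
int main() {
    xnn_initialize(/*allocator=*/nullptr);

    xnn_subgraph_t subgraph = nullptr;
    xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph);

    const size_t dims[1] = {4};
    uint32_t aId, bId, outId;
    xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 1, dims, /*data=*/nullptr,
                            /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &aId);
    xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 1, dims, nullptr,
                            /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_INPUT, &bId);
    xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 1, dims, nullptr,
                            /*external_id=*/2, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &outId);
    xnn_define_add2(subgraph, -INFINITY, +INFINITY, aId, bId, outId, /*flags=*/0);

    // The runtime owns the compiled graph; the subgraph can be deleted afterwards
    // (Graph::Finish below does the same).
    xnn_runtime_t runtime = nullptr;
    xnn_create_runtime_v2(subgraph, /*threadpool=*/nullptr, /*flags=*/0, &runtime);
    xnn_delete_subgraph(subgraph);

    std::array<float, 4> a = {1, 2, 3, 4}, b = {10, 20, 30, 40}, out = {};
    std::array<xnn_external_value, 3> externals = {
        {{0, a.data()}, {1, b.data()}, {2, out.data()}}};
    xnn_setup_runtime(runtime, externals.size(), externals.data());
    xnn_invoke_runtime(runtime);  // out == {11, 22, 33, 44}

    xnn_delete_runtime(runtime);
    return 0;
}
```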
diff --git a/src/webnn_native/xnnpack/GraphXNN.cpp b/src/webnn_native/xnnpack/GraphXNN.cpp
index 8395c4dd5..2bffdf242 100644
--- a/src/webnn_native/xnnpack/GraphXNN.cpp
+++ b/src/webnn_native/xnnpack/GraphXNN.cpp
@@ -27,8 +27,6 @@
 
 #define FAILED(status) (((xnn_status)(status)) != xnn_status_success)
 
-#define XNNPACK_MAX_VALUE_ID 10000
-
 const char* xnn_status2str(xnn_status v) {
     if (v == xnn_status_success)
         return "success";
@@ -89,317 +87,261 @@ namespace webnn_native::xnnpack {
             }
             return xnn_status_success;
         }
-
-        size_t SizeOfXnnDataType(xnn_datatype dataType) {
-            if (dataType == xnn_datatype_fp32) {
-                return sizeof(float);
-            } else if (dataType == xnn_datatype_fp16) {
-                return sizeof(uint16_t);
-            } else if (dataType == xnn_datatype_qint32) {
-                return sizeof(int32_t);
-            } else if (dataType == xnn_datatype_qint8) {
-                return sizeof(int8_t);
-            }
-            return 0;
-        }
-
-        xnn_status BroadcastDimensions(const std::vector<size_t>& aDims,
-                                       const std::vector<size_t>& bDims,
-                                       std::vector<size_t>& cDims) {
-            cDims.resize(std::max(aDims.size(), bDims.size()));
-            for (size_t i = 0; i < cDims.size(); ++i) {
-                size_t aDim = i < aDims.size() ? aDims[aDims.size() - i - 1] : 1;
-                size_t bDim = i < bDims.size() ? bDims[bDims.size() - i - 1] : 1;
-                size_t cIndex = cDims.size() - i - 1;
-                if (aDim == 1 && bDim != 1) {
-                    cDims[cIndex] = bDim;
-                } else if (aDim != 1 && bDim == 1) {
-                    cDims[cIndex] = aDim;
-                } else if (aDim == bDim) {
-                    cDims[cIndex] = aDim;
-                } else {
-                    return xnn_status_invalid_parameter;
-                }
-            }
-            return xnn_status_success;
-        }
-
-        size_t GetEffectiveFilterSize(size_t filterSize, size_t dilation) {
-            if (dilation <= 1) {
-                return filterSize;
-            }
-            return filterSize + (filterSize - 1) * (dilation - 1);
-        }
-
-        size_t ComputeConv2DOutputSize(size_t input,
-                                       size_t filter,
-                                       size_t padBegin,
-                                       size_t padEnd,
-                                       size_t stride,
-                                       size_t dilation) {
-            size_t effectiveFilter = GetEffectiveFilterSize(filter, dilation);
-            return (input - effectiveFilter + padBegin + padEnd) / stride + 1;
-        }
     }  // anonymous namespace
 
-    Graph::Graph(Context* context) : GraphBase(context), mXnnOperator(nullptr) {
+    Graph::Graph(Context* context) : GraphBase(context), mExternalId(0), mRuntime(nullptr) {
     }
 
     Graph::~Graph() {
-        if (mXnnOperator) {
-            if (FAILED(xnn_delete_operator(mXnnOperator))) {
-                dawn::ErrorLog() << "xnn_delete_operator failed.";
-            }
-        }
-    }
-
-    MaybeError Graph::AddConstant(const op::Constant* constant) {
-        std::shared_ptr<OperandInfo> info = std::make_shared<OperandInfo>(OperandType::CONSTANT);
-        const OperandDescriptor* desc = constant->GetOperandDescriptor();
-        DAWN_TRY(GetXnnDataType(desc->type, info->dataType));
-        info->dims.assign(desc->dimensions, desc->dimensions + desc->dimensionsCount);
-        info->buffer.reset(new char[constant->GetByteLength()]);
-        if (info->buffer.get() == nullptr) {
-            return DAWN_OUT_OF_MEMORY_ERROR("");
-        }
-        memcpy(info->buffer.get(), constant->GetBuffer(), constant->GetByteLength());
-        mConstants.push_back(info);
-        mOperandInfoMap.insert(std::make_pair(constant, info));
-        return {};
     }
 
     MaybeError Graph::AddInput(const op::Input* input) {
-        std::shared_ptr<OperandInfo> info = std::make_shared<OperandInfo>(OperandType::INPUT);
-        const OperandDescriptor* desc = input->GetOperandDescriptor();
-        DAWN_TRY(GetXnnDataType(desc->type, info->dataType));
-        info->dims.assign(desc->dimensions, desc->dimensions + desc->dimensionsCount);
-        info->name = input->GetName();
-        mOperandInfoMap.insert(std::make_pair(input, info));
+        mOperators.push_back({OperatorType::Input, input});
+        uint32_t inputId = mExternalId++;
+        mInputs.insert(std::make_pair(input->PrimaryOutput(), inputId));
+        mExternals.insert(std::make_pair(input->GetName(), inputId));
         return {};
     }
 
     MaybeError Graph::AddOutput(std::string_view name, const OperandBase* op) {
-        std::shared_ptr<OperandInfo>& info = mOperandInfoMap.at(op);
-        if (info->opType == OperandType::INPUT || info->opType == OperandType::CONSTANT) {
-            return DAWN_INTERNAL_ERROR("There is no operator to be created.");
-        }
-        info->name = name.data();
+        uint32_t outputId = mExternalId++;
+        mOutputs.insert(std::make_pair(op, outputId));
+        mExternals.insert(std::make_pair(name, outputId));
         return {};
     }
 
-    MaybeError Graph::AddBinary(const op::Binary* binary) {
-        std::shared_ptr<OperandInfo> info = std::make_shared<OperandInfo>(OperandType::BINARY);
-        mOperandInfoMap.insert(std::make_pair(binary, info));
-        mOperandsToBuild.push_back(binary);
-        return {};
+#define GRAPH_ADD_OP(OpType)                                  \
+    MaybeError Graph::Add##OpType(const op::OpType* op) {     \
+        mOperators.push_back({OperatorType::OpType, op});     \
+        return {};                                            \
     }
 
-    MaybeError Graph::AddClamp(const op::Clamp* clamp) {
-        std::shared_ptr<OperandInfo> info = std::make_shared<OperandInfo>(OperandType::CLAMP);
-        mOperandInfoMap.insert(std::make_pair(clamp, info));
-        mOperandsToBuild.push_back(clamp);
-        return {};
-    }
-
-    MaybeError Graph::AddConv2d(const op::Conv2d* conv2d) {
-        std::shared_ptr<OperandInfo> info = std::make_shared<OperandInfo>(OperandType::CONV2D);
-        mOperandInfoMap.insert(std::make_pair(conv2d, info));
-        mOperandsToBuild.push_back(conv2d);
-        return {};
-    }
-
-    MaybeError Graph::AddPool2d(const op::Pool2d* pool2d) {
-        std::shared_ptr<OperandInfo> info = std::make_shared<OperandInfo>(OperandType::POOL2D);
-        mOperandInfoMap.insert(std::make_pair(pool2d, info));
-        mOperandsToBuild.push_back(pool2d);
-        return {};
-    }
-
-    MaybeError Graph::AddUnary(const op::Unary* unary) {
-        std::shared_ptr<OperandInfo> info = std::make_shared<OperandInfo>(OperandType::UNARY);
-        mOperandInfoMap.insert(std::make_pair(unary, info));
-        mOperandsToBuild.push_back(unary);
-        return {};
-    }
-
-    MaybeError Graph::Finish() {
-        if (mOperandsToBuild.size() == 0) {
-            return DAWN_INTERNAL_ERROR("No operators to build.");
-        }
-        const OperandBase* op = mOperandsToBuild[0];
-        DAWN_ASSERT(mOperandInfoMap.find(op) != mOperandInfoMap.end());
-        std::shared_ptr<OperandInfo>& info = mOperandInfoMap.at(op);
-        if (mOperandsToBuild.size() == 1) {
-            if (info->opType == OperandType::UNARY) {
-                DAWN_TRY(CreateXnnOp(reinterpret_cast<const op::Unary*>(op)));
-            } else if (info->opType == OperandType::CLAMP) {
-                DAWN_TRY(CreateXnnOp(reinterpret_cast<const op::Clamp*>(op)));
-            } else if (info->opType == OperandType::BINARY) {
-                DAWN_TRY(CreateXnnOp(reinterpret_cast<const op::Binary*>(op)));
-            } else if (info->opType == OperandType::CONV2D) {
-                DAWN_TRY(CreateXnnOp(reinterpret_cast<const op::Conv2d*>(op)));
-            } else if (info->opType == OperandType::POOL2D) {
-                DAWN_TRY(CreateXnnOp(reinterpret_cast<const op::Pool2d*>(op)));
-            } else {
-                return DAWN_UNIMPLEMENTED_ERROR("");
-            }
-        } else if (info->opType == OperandType::CONV2D) {
-            // Try to fuse add and clamp into conv2d
-            const op::Conv2d* conv2d = reinterpret_cast<const op::Conv2d*>(op);
-            if (mOperandsToBuild.size() > 3) {
-                return DAWN_INTERNAL_ERROR("Cannot fuse conv2d subgraph with more than 3 ops.");
-            }
-            const op::Binary* add = nullptr;
-            const op::Clamp* clamp = nullptr;
-            for (auto& operand : mOperandsToBuild) {
-                DAWN_ASSERT(mOperandInfoMap.find(operand) != mOperandInfoMap.end());
-                std::shared_ptr<OperandInfo>& operandInfo = mOperandInfoMap.at(operand);
-                if (operandInfo->opType == OperandType::BINARY &&
-                    reinterpret_cast<const op::Binary*>(operand)->GetType() ==
-                        op::BinaryOpType::kAdd) {
-                    add = reinterpret_cast<const op::Binary*>(operand);
-                } else if (operandInfo->opType == OperandType::CLAMP) {
-                    clamp = reinterpret_cast<const op::Clamp*>(operand);
-                }
-            }
-            if ((mOperandsToBuild.size() == 2 && !add && !clamp) ||
-                (mOperandsToBuild.size() == 3 && (!add || !clamp))) {
-                return DAWN_INTERNAL_ERROR("Failed to fuse conv2d subgraph.");
-            }
-            DAWN_TRY(CreateXnnOp(conv2d, add, clamp));
+    GRAPH_ADD_OP(Binary)
+    GRAPH_ADD_OP(Clamp)
+    GRAPH_ADD_OP(Concat)
+    GRAPH_ADD_OP(Conv2d)
+    GRAPH_ADD_OP(Constant)
+    GRAPH_ADD_OP(Gemm)
+    GRAPH_ADD_OP(Pad)
+    GRAPH_ADD_OP(Pool2d)
+    GRAPH_ADD_OP(Reshape)
+    GRAPH_ADD_OP(Split)
+    GRAPH_ADD_OP(Squeeze)
+    GRAPH_ADD_OP(Unary)
+
+    xnn_status Graph::DefineXnnTensorValue(xnn_subgraph_t subgraph,
+                                           const OperandBase* operand,
+                                           uint32_t* id,
+                                           const void* data) {
+        xnn_datatype datatype = xnn_datatype_invalid;
+        if (GetXnnDataType(operand->Type(), datatype) != xnn_status_success) {
+            // Ignore the unsupported data type; it may be used for attributes, such as padding.
+            return xnn_status_success;
         }
-        return {};
-    }
-
-    xnn_status Graph::CreateXnnOp(const op::Unary* unary) {
-        DAWN_ASSERT(unary->Inputs().size() == 1);
-        const OperandBase* inputOperand = unary->Inputs()[0].Get();
-        DAWN_ASSERT(mOperandInfoMap.find(inputOperand) != mOperandInfoMap.end());
-        const std::shared_ptr<OperandInfo>& inputInfo = mOperandInfoMap.at(inputOperand);
-        mInputs.push_back(inputInfo);
-        if (inputInfo->opType == OperandType::INPUT) {
-            mExternalInputs.insert(std::make_pair(inputInfo->name, mInputs.size() - 1));
-        }
-        if (unary->GetType() == op::UnaryOpType::kRelu) {
-            XNN_TRY(xnn_create_clamp_nc_f32(1, 1, 1, 0, +std::numeric_limits<float>::infinity(), 0,
-                                            &mXnnOperator));
-            mXnnOperatorType = XnnOpType::clamp_nc_f32;
+        std::vector<size_t> dims;
+        for (auto& d : operand->Shape()) {
+            dims.push_back(static_cast<size_t>(d));
+        }
+        uint32_t flags = 0;
+        uint32_t externalId;
+        if (mInputs.find(operand) != mInputs.end()) {
+            externalId = mInputs.at(operand);
+            flags |= XNN_VALUE_FLAG_EXTERNAL_INPUT;
+        } else if (mOutputs.find(operand) != mOutputs.end()) {
+            externalId = mOutputs.at(operand);
+            flags |= XNN_VALUE_FLAG_EXTERNAL_OUTPUT;
         } else {
-            return xnn_status_unsupported_parameter;
+            externalId = XNN_INVALID_VALUE_ID;
         }
-        std::shared_ptr<OperandInfo>& outputInfo = mOperandInfoMap.at(unary);
-        outputInfo->dataType = inputInfo->dataType;
-        outputInfo->dims = inputInfo->dims;
-        mOutputs.push_back(outputInfo);
-        mExternalOutputs.insert(std::make_pair(outputInfo->name, mOutputs.size() - 1));
+        XNN_TRY(xnn_define_tensor_value(subgraph, datatype, dims.size(), dims.data(), data,
+                                        externalId, flags, id));
+        mOperands.insert(std::make_pair(operand, *id));
         return xnn_status_success;
     }
 
-    xnn_status Graph::CreateXnnOp(const op::Clamp* clamp) {
-        const OperandBase* inputOperand = clamp->Inputs()[0].Get();
-        DAWN_ASSERT(mOperandInfoMap.find(inputOperand) != mOperandInfoMap.end());
-        const std::shared_ptr<OperandInfo>& inputInfo = mOperandInfoMap.at(inputOperand);
-        mInputs.push_back(inputInfo);
-        if (inputInfo->opType == OperandType::INPUT) {
-            mExternalInputs.insert(std::make_pair(inputInfo->name, mInputs.size() - 1));
-        }
-        const ClampOptions* options = clamp->GetOptions();
-        float minValue = -std::numeric_limits<float>::infinity();
-        if (options->minValue != nullptr) {
-            const std::shared_ptr<OperandInfo>& minInfo = mOperandInfoMap.at(options->minValue);
-            if (minInfo->opType != OperandType::CONSTANT) {
-                dawn::ErrorLog() << "XNNPACK only supports clamp by value.";
-                return xnn_status_invalid_parameter;
-            }
-            minValue = (reinterpret_cast<float*>(minInfo->buffer.get()))[0];
+    xnn_status Graph::DefineXnnNode(xnn_subgraph_t subgraph, const op::Constant* constant) {
+        std::unique_ptr<char[]> buffer(new char[constant->GetByteLength()]);
+        if (buffer.get() == nullptr) {
+            return xnn_status_out_of_memory;
         }
-        float maxValue = +std::numeric_limits<float>::infinity();
-        if (options->maxValue != nullptr) {
-            const std::shared_ptr<OperandInfo>& maxInfo = mOperandInfoMap.at(options->maxValue);
-            if (maxInfo->opType != OperandType::CONSTANT) {
-                dawn::ErrorLog() << "XNNPACK only supports clamp by value.";
-                return xnn_status_invalid_parameter;
-            }
-            maxValue = (reinterpret_cast<float*>(maxInfo->buffer.get()))[0];
-        }
-        XNN_TRY(xnn_create_clamp_nc_f32(1, 1, 1, minValue, maxValue, 0, &mXnnOperator));
-        mXnnOperatorType = XnnOpType::clamp_nc_f32;
-        std::shared_ptr<OperandInfo>& outputInfo = mOperandInfoMap.at(clamp);
-        outputInfo->dataType = inputInfo->dataType;
-        outputInfo->dims = inputInfo->dims;
-        mOutputs.push_back(outputInfo);
-        mExternalOutputs.insert(std::make_pair(outputInfo->name, mOutputs.size() - 1));
+        memcpy(buffer.get(), constant->GetBuffer(), constant->GetByteLength());
+        uint32_t id;
+        XNN_TRY(DefineXnnTensorValue(subgraph, constant->PrimaryOutput(), &id, buffer.get()));
+        mOperands.insert(std::make_pair(constant->PrimaryOutput(), id));
+        mBuffers.push_back(std::move(buffer));
         return xnn_status_success;
     }
 
+    xnn_status Graph::DefineXnnNode(xnn_subgraph_t subgraph, const op::Input* input) {
+        DAWN_ASSERT(mInputs.find(input->PrimaryOutput()) != mInputs.end());
+        uint32_t id;
+        XNN_TRY(DefineXnnTensorValue(subgraph, input->PrimaryOutput(), &id));
+        mOperands.insert(std::make_pair(input->PrimaryOutput(), id));
         return xnn_status_success;
     }
 
-    xnn_status Graph::CreateXnnOp(const op::Binary* binary) {
+    xnn_status Graph::DefineXnnNode(xnn_subgraph_t subgraph, const op::Binary* binary) {
         DAWN_ASSERT(binary->Inputs().size() == 2);
         const OperandBase* input0Operand = binary->Inputs()[0].Get();
-        DAWN_ASSERT(mOperandInfoMap.find(input0Operand) != mOperandInfoMap.end());
-        const std::shared_ptr<OperandInfo>& input0Info = mOperandInfoMap.at(input0Operand);
-        mInputs.push_back(input0Info);
-        if (input0Info->opType == OperandType::INPUT) {
-            mExternalInputs.insert(std::make_pair(input0Info->name, mInputs.size() - 1));
-        }
+        DAWN_ASSERT(mOperands.find(input0Operand) != mOperands.end());
+        uint32_t input0Id = mOperands.at(input0Operand);
         const OperandBase* input1Operand = binary->Inputs()[1].Get();
-        DAWN_ASSERT(mOperandInfoMap.find(input1Operand) != mOperandInfoMap.end());
-        const std::shared_ptr<OperandInfo>& input1Info = mOperandInfoMap.at(input1Operand);
-        mInputs.push_back(input1Info);
-        if (input1Info->opType == OperandType::INPUT) {
-            mExternalInputs.insert(std::make_pair(input1Info->name, mInputs.size() - 1));
-        }
+        DAWN_ASSERT(mOperands.find(input1Operand) != mOperands.end());
+        uint32_t input1Id = mOperands.at(input1Operand);
+        auto outputOperand = binary->PrimaryOutput();
+        uint32_t outputId;
+        XNN_TRY(DefineXnnTensorValue(subgraph, outputOperand, &outputId));
         const float outputMin = -std::numeric_limits<float>::infinity();
         const float outputMax = +std::numeric_limits<float>::infinity();
-        if (binary->GetType() == op::BinaryOpType::kAdd) {
-            XNN_TRY(xnn_create_add_nd_f32(outputMin, outputMax, 0, &mXnnOperator));
-            mXnnOperatorType = XnnOpType::add_nd_f32;
-        } else if (binary->GetType() == op::BinaryOpType::kMul) {
-            XNN_TRY(xnn_create_multiply_nd_f32(outputMin, outputMax, 0, &mXnnOperator));
-            mXnnOperatorType = XnnOpType::multiply_nd_f32;
-        } else if (binary->GetType() == op::BinaryOpType::kSub) {
-            XNN_TRY(xnn_create_subtract_nd_f32(outputMin, outputMax, 0, &mXnnOperator));
-            mXnnOperatorType = XnnOpType::subtract_nd_f32;
-        } else {
-            return xnn_status_unsupported_parameter;
+        switch (binary->GetType()) {
+            case op::BinaryOpType::kAdd:
+                XNN_TRY(xnn_define_add2(subgraph, outputMin, outputMax, input0Id, input1Id,
+                                        outputId, 0));
+                break;
+            case op::BinaryOpType::kDiv:
+                XNN_TRY(xnn_define_divide(subgraph, outputMin, outputMax, input0Id, input1Id,
+                                          outputId, 0));
+                break;
+            case op::BinaryOpType::kMax:
+                XNN_TRY(xnn_define_maximum2(subgraph, input0Id, input1Id, outputId, 0));
+                break;
+            case op::BinaryOpType::kMin:
+                XNN_TRY(xnn_define_minimum2(subgraph, input0Id, input1Id, outputId, 0));
+                break;
+            case op::BinaryOpType::kMul:
+                XNN_TRY(xnn_define_multiply2(subgraph, outputMin, outputMax, input0Id, input1Id,
+                                             outputId, 0));
+                break;
+            case op::BinaryOpType::kSub:
+                XNN_TRY(xnn_define_subtract(subgraph, outputMin, outputMax, input0Id, input1Id,
+                                            outputId, 0));
+                break;
+            case op::BinaryOpType::kMatMul:
+                if (input1Operand->Shape().size() != 2) {
+                    dawn::ErrorLog() << "XNNPACK backend only supports 2D operand b of matmul.";
+                    return xnn_status_invalid_parameter;
+                }
+                XNN_TRY(xnn_define_fully_connected(subgraph, outputMin, outputMax, input0Id,
+                                                   input1Id, XNN_INVALID_VALUE_ID, outputId,
+                                                   XNN_FLAG_TRANSPOSE_WEIGHTS));
+                break;
+            default:
+                dawn::ErrorLog() << "XNNPACK backend doesn't support binary op "
+                                 << static_cast<int>(binary->GetType());
+                return xnn_status_unsupported_parameter;
         }
-        std::shared_ptr<OperandInfo>& outputInfo = mOperandInfoMap.at(binary);
-        outputInfo->dataType = input0Info->dataType;
-        XNN_TRY(BroadcastDimensions(input0Info->dims, input1Info->dims, outputInfo->dims));
-        mOutputs.push_back(outputInfo);
-        mExternalOutputs.insert(std::make_pair(outputInfo->name, mOutputs.size() - 1));
         return xnn_status_success;
     }
 
-    xnn_status Graph::CreateXnnOp(const op::Pool2d* pool2d) {
-        DAWN_ASSERT(pool2d->Inputs().size() == 1);
-        const OperandBase* inputOperand = pool2d->Inputs()[0].Get();
-        DAWN_ASSERT(mOperandInfoMap.find(inputOperand) != mOperandInfoMap.end());
-        const std::shared_ptr<OperandInfo>& inputInfo = mOperandInfoMap.at(inputOperand);
-        mInputs.push_back(inputInfo);
-        if (inputInfo->opType == OperandType::INPUT) {
-            mExternalInputs.insert(std::make_pair(inputInfo->name, mInputs.size() - 1));
-        }
-        const Pool2dOptions* options = pool2d->GetOptions();
-        if (options->layout != wnn::InputOperandLayout::Nhwc) {
-            dawn::ErrorLog() << "XNNPACK only supports input layout nhwc.";
+    xnn_status Graph::DefineXnnNode(xnn_subgraph_t subgraph, const op::Clamp* clamp) {
+        DAWN_ASSERT(clamp->Inputs().size() == 1);
+        auto inputOperand = clamp->Inputs()[0].Get();
+        DAWN_ASSERT(mOperands.find(inputOperand) != mOperands.end());
+        uint32_t inputId = mOperands.at(inputOperand);
+        auto outputOperand = clamp->PrimaryOutput();
+        uint32_t outputId;
+        XNN_TRY(DefineXnnTensorValue(subgraph, outputOperand, &outputId));
+        XNN_TRY(xnn_define_clamp(subgraph, clamp->GetMinValue(), clamp->GetMaxValue(), inputId,
+                                 outputId, 0));
+        return xnn_status_success;
+    }
+
+    xnn_status Graph::DefineXnnNode(xnn_subgraph_t subgraph, const op::Concat* concat) {
+        auto inputOperands = concat->Inputs();
+        DAWN_ASSERT(inputOperands.size() >= 1);
+        if (inputOperands.size() > 4) {
+            dawn::ErrorLog() << "XNNPACK backend doesn't support concat inputs size "
+                             << inputOperands.size();
             return xnn_status_invalid_parameter;
         }
+        std::vector<uint32_t> inputIds(inputOperands.size());
+        for (size_t i = 0; i < inputOperands.size(); ++i) {
+            DAWN_ASSERT(mOperands.find(inputOperands[i].Get()) != mOperands.end());
+            inputIds[i] = mOperands.at(inputOperands[i].Get());
+        }
+        auto outputOperand = concat->PrimaryOutput();
+        uint32_t outputId;
+        XNN_TRY(DefineXnnTensorValue(subgraph, outputOperand, &outputId));
+        size_t axis = concat->GetAxis();
+        switch (concat->Inputs().size()) {
+            case 2:
+                XNN_TRY(
+                    xnn_define_concatenate2(subgraph, axis, inputIds[0], inputIds[1], outputId, 0));
+                break;
+            case 3:
+                XNN_TRY(xnn_define_concatenate3(subgraph, axis, inputIds[0], inputIds[1],
+                                                inputIds[2], outputId, 0));
+                break;
+            case 4:
+                XNN_TRY(xnn_define_concatenate4(subgraph, axis, inputIds[0], inputIds[1],
+                                                inputIds[2], inputIds[3], outputId, 0));
+                break;
+            default:
+                dawn::ErrorLog() << "XNNPACK backend doesn't support concat inputs size "
+                                 << inputOperands.size();
+                return xnn_status_invalid_parameter;
+        }
+        return xnn_status_success;
+    }
+
+    xnn_status Graph::DefineXnnNode(xnn_subgraph_t subgraph, const op::Conv2d* conv2d) {
+        auto inputOperands = conv2d->Inputs();
+        DAWN_ASSERT(inputOperands.size() == 2 || inputOperands.size() == 3);
+        auto inputOperand = inputOperands[0].Get();
+        DAWN_ASSERT(mOperands.find(inputOperand) != mOperands.end());
+        uint32_t inputId = mOperands.at(inputOperand);
+        auto filterOperand = inputOperands[1].Get();
+        DAWN_ASSERT(mOperands.find(filterOperand) != mOperands.end());
+        uint32_t filterId = mOperands.at(filterOperand);
+        uint32_t biasId = XNN_INVALID_VALUE_ID;
+        if (inputOperands.size() == 3) {
+            DAWN_ASSERT(mOperands.find(inputOperands[2].Get()) != mOperands.end());
+            biasId = mOperands.at(inputOperands[2].Get());
+        }
+        auto outputOperand = conv2d->PrimaryOutput();
+
+        const Conv2dOptions* options = conv2d->GetOptions();
+        uint32_t groups = options->groups;
         uint32_t strideHeight = options->strides[0];
         uint32_t strideWidth = options->strides[1];
         uint32_t dilationHeight = options->dilations[0];
         uint32_t dilationWidth = options->dilations[1];
-        // nhwc
-        size_t inputHeight = inputInfo->dims[1];
-        size_t inputWidth = inputInfo->dims[2];
-        size_t channels = inputInfo->dims[3];
+        size_t inputHeight, inputWidth;
         uint32_t filterHeight, filterWidth;
-        if (options->windowDimensions != nullptr) {
-            filterHeight = options->windowDimensions[0];
-            filterWidth = options->windowDimensions[1];
+        size_t inputChannels, outputChannels;
+        bool depthwise = false;
+        if (options->inputLayout == wnn::InputOperandLayout::Nhwc) {
+            inputHeight = inputOperand->Shape()[1];
+            inputWidth = inputOperand->Shape()[2];
+            inputChannels = inputOperand->Shape()[3];
+            depthwise = (groups == inputChannels);
+            if (!depthwise) {
+                // For regular conv2d, XNNPACK expects weights laid out as ohwi:
+                // [groups * group_output_channels, kernel_height, kernel_width,
+                // group_input_channels]
+                if (options->filterLayout != wnn::Conv2dFilterOperandLayout::Ohwi) {
+                    dawn::ErrorLog()
+                        << "XNNPACK backend only supports filter layout ohwi for conv2d.";
+                    return xnn_status_invalid_parameter;
+                }
+            } else {
+                // For depthwise conv2d, XNNPACK expects weights laid out as ihwo:
+                // [1, kernel_height, kernel_width, input_channels * depth_multiplier]
+                if (options->filterLayout != wnn::Conv2dFilterOperandLayout::Ihwo) {
+                    dawn::ErrorLog()
+                        << "XNNPACK backend only supports filter layout ihwo for depthwise conv2d.";
+                    return xnn_status_invalid_parameter;
+                }
+            }
+            filterHeight = filterOperand->Shape()[1];
+            filterWidth = filterOperand->Shape()[2];
+            outputChannels = outputOperand->Shape()[3];
         } else {
-            filterHeight = inputHeight;
-            filterWidth = inputWidth;
+            dawn::ErrorLog() << "XNNPACK backend only supports input layout nhwc.";
+            return xnn_status_invalid_parameter;
         }
+        size_t groupInputChannels = inputChannels / groups;
+        size_t groupOutputChannels = outputChannels / groups;
 
         size_t outputHeight, outputWidth;
         uint32_t padTop, padBottom, padLeft, padRight;
@@ -409,10 +351,6 @@ namespace webnn_native::xnnpack {
             padBottom = options->padding[1];
             padLeft = options->padding[2];
             padRight = options->padding[3];
-            outputHeight = ComputeConv2DOutputSize(inputHeight, filterHeight, padTop, padBottom,
-                                                   strideHeight, dilationHeight);
-            outputWidth = ComputeConv2DOutputSize(inputWidth, filterWidth, padLeft, padRight,
-                                                  strideWidth, dilationWidth);
         } else {
             outputHeight = ceil(inputHeight / strideHeight);
             outputWidth = ceil(inputWidth / strideWidth);
@@ -435,120 +373,138 @@ namespace webnn_native::xnnpack {
 
         float outputMin = -std::numeric_limits<float>::infinity();
         float outputMax = +std::numeric_limits<float>::infinity();
-        const uint32_t flags = 0;
-        if (pool2d->GetType() == op::Pool2dType::kAveragePool2d) {
-            if (dilationHeight != 1 || dilationWidth != 1) {
-                dawn::ErrorLog() << "XNNPACK does not support dilation for averagePool2d.";
-                return xnn_status_invalid_parameter;
+        if (options->activation) {
+            switch (options->activation->GetFusionType()) {
+                case FusionType::Clamp: {
+                    auto clamp = reinterpret_cast<op::FusionClamp*>(options->activation);
+                    outputMin = clamp->GetMinValue();
+                    outputMax = clamp->GetMaxValue();
+                    break;
+                }
+                case FusionType::Relu:
+                    outputMin = 0.0f;
+                    outputMax = std::numeric_limits<float>::infinity();
+                    break;
+                default:
+                    dawn::ErrorLog() << "XNNPACK backend doesn't support fused operator "
+                                     << static_cast<int>(options->activation->GetFusionType());
+                    return xnn_status_invalid_parameter;
             }
-            XNN_TRY(xnn_create_average_pooling2d_nhwc_f32(
-                padTop, padRight, padBottom, padLeft, filterHeight, filterWidth, strideHeight,
-                strideWidth, channels, channels, channels, outputMin, outputMax, flags,
-                &mXnnOperator));
-            mXnnOperatorType = XnnOpType::average_pooling2d_nhwc_f32;
-        } else if (pool2d->GetType() == op::Pool2dType::kMaxPool2d) {
-            XNN_TRY(xnn_create_max_pooling2d_nhwc_f32(
-                padTop, padRight, padBottom, padLeft, filterHeight, filterWidth, strideHeight,
-                strideWidth, dilationHeight, dilationWidth, channels, channels, channels, outputMin,
-                outputMax, flags, &mXnnOperator));
-            mXnnOperatorType = XnnOpType::max_pooling2d_nhwc_f32;
+        }
+        uint32_t outputId;
+        XNN_TRY(DefineXnnTensorValue(subgraph, outputOperand, &outputId));
+        if (depthwise) {
+            XNN_TRY(xnn_define_depthwise_convolution_2d(
+                subgraph, padTop, padRight, padBottom, padLeft, filterHeight, filterWidth,
+                strideHeight, strideWidth, dilationHeight, dilationWidth, 1, inputChannels,
+                outputMin, outputMax, inputId, filterId, biasId, outputId, 0));
         } else {
-            dawn::ErrorLog() << "XNNPACK does not support l2Pool2d.";
+            XNN_TRY(xnn_define_convolution_2d(subgraph, padTop, padRight, padBottom, padLeft,
+                                              filterHeight, filterWidth, strideHeight, strideWidth,
+                                              dilationHeight, dilationWidth, groups,
+                                              groupInputChannels, groupOutputChannels, outputMin,
+                                              outputMax, inputId, filterId, biasId, outputId, 0));
+        }
+        return xnn_status_success;
+    }
+
+    xnn_status Graph::DefineXnnNode(xnn_subgraph_t subgraph, const op::Gemm* gemm) {
+        auto inputs = gemm->Inputs();
+        DAWN_ASSERT(inputs.size() == 2 || inputs.size() == 3);
+        DAWN_ASSERT(mOperands.find(inputs[0].Get()) != mOperands.end());
+        uint32_t inputId = mOperands.at(inputs[0].Get());
+        DAWN_ASSERT(mOperands.find(inputs[1].Get()) != mOperands.end());
+        uint32_t filterId = mOperands.at(inputs[1].Get());
+        uint32_t biasId = XNN_INVALID_VALUE_ID;
+        if (inputs.size() == 3) {
+            DAWN_ASSERT(mOperands.find(inputs[2].Get()) != mOperands.end());
+            biasId = mOperands.at(inputs[2].Get());
+        }
+        const GemmOptions* options = gemm->GetOptions();
+        if (fabs(options->alpha - 1.0f) > std::numeric_limits<float>::epsilon()) {
+            dawn::ErrorLog() << "XNNPACK backend doesn't support alpha " << options->alpha;
             return xnn_status_invalid_parameter;
         }
-        const std::shared_ptr<OperandInfo> outputInfo = mOperandInfoMap.at(pool2d);
-        outputInfo->dataType = inputInfo->dataType;
-        size_t batchSize = inputInfo->dims[0];
-        // nchw
-        outputInfo->dims = {batchSize, outputHeight, outputWidth, channels};
-        mOutputs.push_back(outputInfo);
-        mExternalOutputs.insert(std::make_pair(outputInfo->name, mOutputs.size() - 1));
+        if (fabs(options->beta - 1.0f) > std::numeric_limits<float>::epsilon()) {
+            dawn::ErrorLog() << "XNNPACK backend doesn't support beta " << options->beta;
+            return xnn_status_invalid_parameter;
+        }
+        if (options->aTranspose) {
+            dawn::ErrorLog() << "XNNPACK backend doesn't support aTranspose.";
+            return xnn_status_invalid_parameter;
+        }
+        // WebNN's b is [K, N] unless bTranspose is set, while XNNPACK's fully connected
+        // expects [N, K] weights by default, so request transposed weights for [K, N].
+        uint32_t flags = 0;
+        if (!options->bTranspose) {
+            flags = XNN_FLAG_TRANSPOSE_WEIGHTS;
+        }
+        auto outputOperand = gemm->PrimaryOutput();
+        uint32_t outputId;
+        XNN_TRY(DefineXnnTensorValue(subgraph, outputOperand, &outputId));
+        const float outputMin = -std::numeric_limits<float>::infinity();
+        const float outputMax = +std::numeric_limits<float>::infinity();
+        XNN_TRY(xnn_define_fully_connected(subgraph, outputMin, outputMax, inputId, filterId,
+                                           biasId, outputId, flags));
         return xnn_status_success;
     }
 
-    xnn_status Graph::CreateXnnOp(const op::Conv2d* conv2d,
-                                  const op::Binary* add,
-                                  const op::Clamp* clamp) {
-        DAWN_ASSERT(conv2d->Inputs().size() == 2);
-        const OperandBase* inputOperand = conv2d->Inputs()[0].Get();
-        DAWN_ASSERT(mOperandInfoMap.find(inputOperand) != mOperandInfoMap.end());
-        const std::shared_ptr<OperandInfo>& inputInfo = mOperandInfoMap.at(inputOperand);
-        mInputs.push_back(inputInfo);
-        if (inputInfo->opType == OperandType::INPUT) {
-            mExternalInputs.insert(std::make_pair(inputInfo->name, mInputs.size() - 1));
-        }
-        const OperandBase* filterOperand = conv2d->Inputs()[1].Get();
-        DAWN_ASSERT(mOperandInfoMap.find(filterOperand) != mOperandInfoMap.end());
-        const std::shared_ptr<OperandInfo>& filterInfo = mOperandInfoMap.at(filterOperand);
-        if (filterInfo->opType != OperandType::CONSTANT) {
-            dawn::ErrorLog() << "filter is not a constant.";
+    xnn_status Graph::DefineXnnNode(xnn_subgraph_t subgraph, const op::Pad* pad) {
+        auto inputOperands = pad->Inputs();
+        DAWN_ASSERT(inputOperands.size() == 2);
+        auto inputOperand = inputOperands[0].Get();
+        size_t inputRank = inputOperand->Shape().size();
+        DAWN_ASSERT(mOperands.find(inputOperand) != mOperands.end());
+        uint32_t inputId = mOperands.at(inputOperand);
+        const op::Constant* paddingConstant =
+            reinterpret_cast<const op::Constant*>(inputOperands[1]->Operator());
+        const PadOptions* options = pad->GetOptions();
+        if (options->mode != wnn::PaddingMode::Constant) {
+            dawn::ErrorLog() << "XNNPACK backend doesn't support padding mode "
+                             << static_cast<int>(options->mode);
             return xnn_status_invalid_parameter;
         }
-        const float* filter = reinterpret_cast<const float*>(filterInfo->buffer.get());
+        float paddingValue = options->value;
+        std::vector<size_t> startPaddingVector;
+        std::vector<size_t> endPaddingVector;
+        const uint32_t* paddingData = static_cast<const uint32_t*>(paddingConstant->GetBuffer());
+        for (size_t i = 0; i < inputRank; ++i) {
+            startPaddingVector.push_back(paddingData[2 * i]);
+            endPaddingVector.push_back(paddingData[2 * i + 1]);
         }
+        auto outputOperand = pad->PrimaryOutput();
+        uint32_t outputId;
+        XNN_TRY(DefineXnnTensorValue(subgraph, outputOperand, &outputId));
+        XNN_TRY(xnn_define_static_constant_pad(subgraph, startPaddingVector.data(),
+                                               endPaddingVector.data(), paddingValue, inputId,
+                                               outputId, 0));
+        return xnn_status_success;
+    }
 
-        const Conv2dOptions* options = conv2d->GetOptions();
-        uint32_t groups = options->groups;
+    xnn_status Graph::DefineXnnNode(xnn_subgraph_t subgraph, const op::Pool2d* pool2d) {
+        DAWN_ASSERT(pool2d->Inputs().size() == 1);
+        auto inputOperand = pool2d->Inputs()[0].Get();
+        DAWN_ASSERT(mOperands.find(inputOperand) != mOperands.end());
+        uint32_t inputId = mOperands.at(inputOperand);
+        const Pool2dOptions* options = pool2d->GetOptions();
+        if (options->layout != wnn::InputOperandLayout::Nhwc) {
+            dawn::ErrorLog() << "XNNPACK only supports input layout nhwc.";
+            return xnn_status_invalid_parameter;
+        }
         uint32_t strideHeight = options->strides[0];
         uint32_t strideWidth = options->strides[1];
         uint32_t dilationHeight = options->dilations[0];
         uint32_t dilationWidth = options->dilations[1];
-        size_t inputHeight, inputWidth;
+        // nhwc
+        size_t inputHeight = inputOperand->Shape()[1];
+        size_t inputWidth = inputOperand->Shape()[2];
         uint32_t filterHeight, filterWidth;
-        size_t inputChannels, outputChannels;
-        if (options->inputLayout == wnn::InputOperandLayout::Nhwc) {
-            inputHeight = inputInfo->dims[1];
-            inputWidth = inputInfo->dims[2];
-            inputChannels = inputInfo->dims[3];
-            if (groups != 1 && groups == inputChannels) {
-                // For depthwiseConv2d, xnn pack expects the weights layout hwio
-                // [filter_height, filter_width, input_channels, channel_multiplier]
-                if (options->filterLayout != wnn::FilterOperandLayout::Hwio) {
-                    dawn::ErrorLog()
-                        << "XNNPACK only supports filter layout hwio for depthwise conv2d.";
-                    return xnn_status_invalid_parameter;
-                }
-                if (filterInfo->dims[2] != 1) {
-                    dawn::ErrorLog() << "The filter layout is invalid.";
-                    return xnn_status_invalid_parameter;
-                }
-                filterHeight = filterInfo->dims[0];
-                filterWidth = filterInfo->dims[1];
-                outputChannels = filterInfo->dims[3];
-            } else {
-                // For regular conv2d, xnn pack expects weights layed out like:
-                // [output_channels, filter_height, filter_width, input_channels]
-                if (options->filterLayout != wnn::FilterOperandLayout::Ohwi) {
-                    dawn::ErrorLog() << "XNNPACK only supports filter layout ohwi for conv2d.";
-                    return xnn_status_invalid_parameter;
-                }
-                if (filterInfo->dims[3] != inputChannels) {
-                    dawn::ErrorLog() << "The filter layout is invalid.";
-                    return xnn_status_invalid_parameter;
-                }
-                outputChannels = filterInfo->dims[0];
-                filterHeight = filterInfo->dims[1];
-                filterWidth = filterInfo->dims[2];
-            }
-        } else {
-            dawn::ErrorLog() << "XNNPACK only supports input layout nhwc.";
-            return xnn_status_invalid_parameter;
-        }
-        const size_t inputChannelStride = inputChannels;
-        const size_t outputChannelStride = outputChannels;
-        size_t groupInputChannels;
-        size_t groupOutputChannels;
-        uint32_t flags = 0;
-        if (groups == 1) {
-            groupInputChannels = inputChannels;
-            groupOutputChannels = outputChannels;
-        } else if (groups == inputChannels) {
-            groupInputChannels = 1;
-            groupOutputChannels = outputChannels / groups;
-            flags |= XNN_FLAG_DEPTHWISE_CONVOLUTION;
+        bool global = false;
+        if (options->windowDimensions != nullptr) {
+            filterHeight = options->windowDimensions[0];
+            filterWidth = options->windowDimensions[1];
         } else {
-            // FIXME(nhu): implement the grouped conv2d.
- dawn::ErrorLog() << "Grouped conv2d is unimplemented."; - return xnn_status_unsupported_parameter; + filterHeight = inputHeight; + filterWidth = inputWidth; + global = true; } size_t outputHeight, outputWidth; @@ -559,10 +515,6 @@ namespace webnn_native::xnnpack { padBottom = options->padding[1]; padLeft = options->padding[2]; padRight = options->padding[3]; - outputHeight = ComputeConv2DOutputSize(inputHeight, filterHeight, padTop, padBottom, - strideHeight, dilationHeight); - outputWidth = ComputeConv2DOutputSize(inputWidth, filterWidth, padLeft, padRight, - strideWidth, dilationWidth); } else { outputHeight = ceil(inputHeight / strideHeight); outputWidth = ceil(inputWidth / strideWidth); @@ -583,89 +535,203 @@ namespace webnn_native::xnnpack { } } - const float* bias = nullptr; - if (add) { - DAWN_ASSERT(add->Inputs().size() == 2); - OperandBase* biasOperand = nullptr; - if (conv2d == add->Inputs()[0].Get()) { - biasOperand = add->Inputs()[1].Get(); - } else if (conv2d == add->Inputs()[1].Get()) { - biasOperand = add->Inputs()[0].Get(); - } else { - dawn::ErrorLog() << "The add is not fusable."; + auto outputOperand = pool2d->PrimaryOutput(); + uint32_t outputId; + XNN_TRY(DefineXnnTensorValue(subgraph, outputOperand, &outputId)); + float outputMin = -std::numeric_limits::infinity(); + float outputMax = +std::numeric_limits::infinity(); + const uint32_t flags = 0; + if (pool2d->GetType() == op::Pool2dType::kAveragePool2d) { + if (dilationHeight != 1 || dilationWidth != 1) { + dawn::ErrorLog() << "XNNPACK does not support dilation for averagePool2d."; return xnn_status_invalid_parameter; } - DAWN_ASSERT(mOperandInfoMap.find(biasOperand) != mOperandInfoMap.end()); - const std::shared_ptr& biasInfo = mOperandInfoMap.at(biasOperand); - if (biasInfo->opType != OperandType::CONSTANT) { - dawn::ErrorLog() << "bias is not a constant."; - return xnn_status_invalid_parameter; + if (global) { + XNN_TRY(xnn_define_global_average_pooling_2d(subgraph, outputMin, outputMax, + inputId, outputId, flags)); + } else { + XNN_TRY(xnn_define_average_pooling_2d( + subgraph, padTop, padRight, padBottom, padLeft, filterHeight, filterWidth, + strideHeight, strideWidth, outputMin, outputMax, inputId, outputId, flags)); } - if (biasInfo->dims.size() != 1 && biasInfo->dims[0] != outputChannels) { - dawn::ErrorLog() << "bias dimensions is invalid."; + } else if (pool2d->GetType() == op::Pool2dType::kMaxPool2d) { + XNN_TRY(xnn_define_max_pooling_2d(subgraph, padTop, padRight, padBottom, padLeft, + filterHeight, filterWidth, strideHeight, strideWidth, + dilationHeight, dilationWidth, outputMin, outputMax, + inputId, outputId, flags)); + } else { + dawn::ErrorLog() << "XNNPACK does not support l2Pool2d."; + return xnn_status_invalid_parameter; + } + return xnn_status_success; + } + + xnn_status Graph::DefineXnnNode(xnn_subgraph_t subgraph, const op::Reshape* reshape) { + DAWN_ASSERT(reshape->Inputs().size() == 1); + auto inputOperand = reshape->Inputs()[0].Get(); + DAWN_ASSERT(mOperands.find(inputOperand) != mOperands.end()); + uint32_t inputId = mOperands.at(inputOperand); + auto outputOperand = reshape->PrimaryOutput(); + std::vector newSizes; + for (auto& d : outputOperand->Shape()) { + newSizes.push_back(static_cast(d)); + } + if (newSizes.size() > XNN_MAX_TENSOR_DIMS) { + dawn::ErrorLog() << "XNNPACK backend doesn't new shape rank " << newSizes.size(); + return xnn_status_invalid_parameter; + } + uint32_t outputId; + XNN_TRY(DefineXnnTensorValue(subgraph, outputOperand, &outputId)); + 
+        XNN_TRY(xnn_define_static_reshape(subgraph, newSizes.size(), newSizes.data(), inputId,
+                                          outputId, 0));
+        return xnn_status_success;
+    }
+
+    xnn_status Graph::DefineXnnNode(xnn_subgraph_t subgraph, const op::Split* split) {
+        DAWN_ASSERT(split->Inputs().size() == 1);
+        auto inputOperand = split->Inputs()[0].Get();
+        DAWN_ASSERT(mOperands.find(inputOperand) != mOperands.end());
+        uint32_t inputId = mOperands.at(inputOperand);
+        if (split->GetSplits().size() != 1) {
+            dawn::ErrorLog() << "XNNPACK backend only supports even split.";
+            return xnn_status_invalid_parameter;
+        }
+        int32_t axis = split->GetAxis();
+        size_t outputSize = split->Outputs().size();
+        if (outputSize > 4) {
+            dawn::ErrorLog() << "XNNPACK backend doesn't support even split into more than 4 outputs.";
+            return xnn_status_invalid_parameter;
+        }
+        std::vector<uint32_t> outputIds(outputSize);
+        for (size_t i = 0; i < outputSize; ++i) {
+            uint32_t outputId;
+            auto outputOperand = split->Outputs()[i].Get();
+            XNN_TRY(DefineXnnTensorValue(subgraph, outputOperand, &outputId));
+            outputIds[i] = outputId;
+        }
+        switch (outputSize) {
+            case 2:
+                XNN_TRY(
+                    xnn_define_even_split2(subgraph, axis, inputId, outputIds[0], outputIds[1], 0));
+                break;
+            case 3:
+                XNN_TRY(xnn_define_even_split3(subgraph, axis, inputId, outputIds[0], outputIds[1],
+                                               outputIds[2], 0));
+                break;
+            case 4:
+                XNN_TRY(xnn_define_even_split4(subgraph, axis, inputId, outputIds[0], outputIds[1],
+                                               outputIds[2], outputIds[3], 0));
+                break;
+            default:
+                dawn::ErrorLog() << "XNNPACK backend doesn't support even split into more than 4 outputs.";
                return xnn_status_invalid_parameter;
-            }
-            bias = reinterpret_cast<const float*>(biasInfo->buffer.get());
         }
+        return xnn_status_success;
+    }
 
-        float outputMin = -std::numeric_limits<float>::infinity();
-        float outputMax = +std::numeric_limits<float>::infinity();
-        if (clamp) {
-            if (add) {
-                if (add != clamp->Inputs()[0].Get()) {
-                    dawn::ErrorLog() << "The clamp is not fusable.";
-                    return xnn_status_invalid_parameter;
-                }
-            } else {
-                if (conv2d != clamp->Inputs()[0].Get()) {
-                    dawn::ErrorLog() << "The clamp is not fusable.";
-                    return xnn_status_invalid_parameter;
-                }
-            }
-            const ClampOptions* options = clamp->GetOptions();
-            if (options->minValue != nullptr) {
-                const std::shared_ptr<OperandInfo>& minInfo = mOperandInfoMap.at(options->minValue);
-                if (minInfo->opType != OperandType::CONSTANT) {
-                    dawn::ErrorLog() << "XNNPACK only supports clamp by value.";
-                    return xnn_status_invalid_parameter;
-                }
-                outputMin = (reinterpret_cast<float*>(minInfo->buffer.get()))[0];
-            }
-            if (options->maxValue != nullptr) {
-                const std::shared_ptr<OperandInfo>& maxInfo = mOperandInfoMap.at(options->maxValue);
-                if (maxInfo->opType != OperandType::CONSTANT) {
-                    dawn::ErrorLog() << "XNNPACK only supports clamp by value.";
-                    return xnn_status_invalid_parameter;
-                }
-                outputMax = (reinterpret_cast<float*>(maxInfo->buffer.get()))[0];
-            }
+    xnn_status Graph::DefineXnnNode(xnn_subgraph_t subgraph, const op::Squeeze* squeeze) {
+        DAWN_ASSERT(squeeze->Inputs().size() == 1);
+        auto inputOperand = squeeze->Inputs()[0].Get();
+        DAWN_ASSERT(mOperands.find(inputOperand) != mOperands.end());
+        uint32_t inputId = mOperands.at(inputOperand);
+        auto outputOperand = squeeze->PrimaryOutput();
+        std::vector<size_t> newSizes;
+        for (auto& d : outputOperand->Shape()) {
+            newSizes.push_back(static_cast<size_t>(d));
+        }
+        if (newSizes.size() > XNN_MAX_TENSOR_DIMS) {
+            dawn::ErrorLog() << "XNNPACK backend doesn't support new size rank " << newSizes.size();
+            return xnn_status_invalid_parameter;
         }
+        uint32_t outputId;
+        XNN_TRY(DefineXnnTensorValue(subgraph, outputOperand, &outputId));
+        XNN_TRY(xnn_define_static_reshape(subgraph, newSizes.size(), newSizes.data(), inputId,
+                                          outputId, 0));
+        return xnn_status_success;
+    }
 
-        XNN_TRY(xnn_create_convolution2d_nhwc_f32(
-            padTop, padRight, padBottom, padLeft, filterHeight, filterWidth, strideHeight,
-            strideWidth, dilationHeight, dilationWidth, groups, groupInputChannels,
-            groupOutputChannels, inputChannelStride, outputChannelStride, filter, bias, outputMin,
-            outputMax, flags, &mXnnOperator));
-        mXnnOperatorType = XnnOpType::convolution2d_nhwc_f32;
-        std::shared_ptr<OperandInfo> outputInfo;
-        if (clamp) {
-            outputInfo = mOperandInfoMap.at(clamp);
-        } else if (add) {
-            outputInfo = mOperandInfoMap.at(add);
-        } else {
-            outputInfo = mOperandInfoMap.at(conv2d);
+    xnn_status Graph::DefineXnnNode(xnn_subgraph_t subgraph, const op::Unary* unary) {
+        DAWN_ASSERT(unary->Inputs().size() == 1);
+        auto inputOperand = unary->Inputs()[0].Get();
+        DAWN_ASSERT(mOperands.find(inputOperand) != mOperands.end());
+        uint32_t inputId = mOperands.at(inputOperand);
+        auto outputOperand = unary->PrimaryOutput();
+        uint32_t outputId;
+        XNN_TRY(DefineXnnTensorValue(subgraph, outputOperand, &outputId));
+        switch (unary->GetType()) {
+            case op::UnaryOpType::kAbs:
+                XNN_TRY(xnn_define_abs(subgraph, inputId, outputId, 0));
+                break;
+            case op::UnaryOpType::kCeil:
+                XNN_TRY(xnn_define_ceiling(subgraph, inputId, outputId, 0));
+                break;
+            case op::UnaryOpType::kFloor:
+                XNN_TRY(xnn_define_floor(subgraph, inputId, outputId, 0));
+                break;
+            case op::UnaryOpType::kHardSwish:
+                XNN_TRY(xnn_define_hardswish(subgraph, inputId, outputId, 0));
+                break;
+            case op::UnaryOpType::kLeakyRelu:
+                XNN_TRY(xnn_define_leaky_relu(
+                    subgraph, reinterpret_cast<const op::LeakyRelu*>(unary)->GetAlpha(), inputId,
+                    outputId, 0));
+                break;
+            case op::UnaryOpType::kNeg:
+                XNN_TRY(xnn_define_negate(subgraph, inputId, outputId, 0));
+                break;
+            case op::UnaryOpType::kRelu:
+                XNN_TRY(xnn_define_clamp(subgraph, 0.0f, std::numeric_limits<float>::infinity(),
+                                         inputId, outputId, 0));
+                break;
+            case op::UnaryOpType::kSigmoid:
+                XNN_TRY(xnn_define_sigmoid(subgraph, inputId, outputId, 0));
+                break;
+            case op::UnaryOpType::kSoftmax:
+                XNN_TRY(xnn_define_softmax(subgraph, inputId, outputId, 0));
+                break;
+            default:
+                dawn::ErrorLog() << "XNNPACK backend doesn't support unary op "
+                                 << static_cast<int>(unary->GetType());
+                return xnn_status_unsupported_parameter;
         }
-        outputInfo->dataType = inputInfo->dataType;
-        size_t batchSize = inputInfo->dims[0];
-        outputInfo->dims = {batchSize, outputHeight, outputWidth, outputChannels};
-        mOutputs.push_back(outputInfo);
-        mExternalOutputs.insert(std::make_pair(outputInfo->name, mOutputs.size() - 1));
         return xnn_status_success;
     }
 
-    size_t Graph::SizeOfOperandInfo(const std::shared_ptr<OperandInfo>& info) {
-        return std::accumulate(info->dims.begin(), info->dims.end(), 1, std::multiplies<size_t>()) *
-               SizeOfXnnDataType(info->dataType);
-    }
+#define HANDLE_OP(OpType)                                                                   \
+    case OperatorType::OpType: {                                                            \
+        DAWN_TRY(DefineXnnNode(subgraph, reinterpret_cast<const op::OpType*>(info.op)));    \
+        break;                                                                              \
+    }
+
+    MaybeError Graph::Finish() {
+        xnn_subgraph_t subgraph;
+        if (FAILED(xnn_create_subgraph(mExternals.size(), 0, &subgraph))) {
+            return DAWN_INTERNAL_ERROR("xnn_create_subgraph failed.");
+        }
+        for (auto const& info : mOperators) {
+            switch (info.type) {
+                HANDLE_OP(Binary)
+                HANDLE_OP(Clamp)
+                HANDLE_OP(Constant)
+                HANDLE_OP(Concat)
+                HANDLE_OP(Conv2d)
+                HANDLE_OP(Gemm)
+                HANDLE_OP(Input)
+                HANDLE_OP(Pad)
+                HANDLE_OP(Pool2d)
+                HANDLE_OP(Reshape)
+                HANDLE_OP(Split)
+                HANDLE_OP(Squeeze)
+                HANDLE_OP(Unary)
+                default: {
+                    return DAWN_UNIMPLEMENTED_ERROR("");
+                }
+            }
+        }
+        uint32_t flags = XNN_FLAG_YIELD_WORKERS;
+        DAWN_TRY(xnn_create_runtime_v2(subgraph, GetThreadpool(), flags, &mRuntime));
+        DAWN_TRY(xnn_delete_subgraph(subgraph));
+        return {};
     }
 
     pthreadpool_t Graph::GetThreadpool() {
@@ -677,107 +743,31 @@ namespace webnn_native::xnnpack {
     }
 
     MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
-        std::vector<void*> inputBuffers(mInputs.size(), nullptr);
-        for (size_t i = 0; i < mInputs.size(); ++i) {
-            if (mInputs[i]->opType == OperandType::CONSTANT) {
-                inputBuffers[i] = mInputs[i]->buffer.get();
-            }
-        }
-        for (auto& [name, input] : inputs->GetRecords()) {
-            if (mExternalInputs.find(name) == mExternalInputs.end()) {
-                return DAWN_INTERNAL_ERROR("Invalid parameters.");
-            }
-            size_t index = mExternalInputs.at(name);
-            auto& resource = input.resource.arrayBufferView;
-            inputBuffers[index] = static_cast<int8_t*>(resource.buffer) + resource.byteOffset;
-        }
-
-        std::vector<std::string> outputNames;
-        for (auto& [name, _] : outputs->GetRecords()) {
-            outputNames.push_back(name);
-        }
-
-        std::vector<void*> outputBuffers(mOutputs.size(), nullptr);
-        for (size_t i = 0; i < outputNames.size(); ++i) {
-            std::string outputName = outputNames[i];
-            size_t outputIndex = mExternalOutputs.at(outputName);
-            const std::shared_ptr<OperandInfo>& outputInfo = mOutputs[outputIndex];
-            std::vector<size_t> dimensions(outputInfo->dims.begin(), outputInfo->dims.end());
-            size_t bufferLength = SizeOfOperandInfo(outputInfo);
-            if (outputs->GetRecords().find(outputName) != outputs->GetRecords().end()) {
-                const ArrayBufferView* output =
-                    outputs->GetRecords().at(outputName).arrayBufferView;
-                DAWN_ASSERT(output->byteLength >= bufferLength);
-                outputBuffers[outputIndex] =
-                    static_cast<int8_t*>(output->buffer) + output->byteOffset;
+        std::vector<xnn_external_value> externalValues;
+        for (auto& input : inputs->GetRecords()) {
+            if (mExternals.find(input.first) == mExternals.end()) {
+                return DAWN_VALIDATION_ERROR("Invalid input.");
             }
+            xnn_external_value value = {};
+            value.id = mExternals.at(input.first);
+            value.data = static_cast<int8_t*>(input.second.resource.arrayBufferView.buffer) +
+                         input.second.resource.arrayBufferView.byteOffset;
+            externalValues.push_back(value);
         }
 
-        if (mXnnOperatorType == XnnOpType::convolution2d_nhwc_f32 ||
-            mXnnOperatorType == XnnOpType::average_pooling2d_nhwc_f32 ||
-            mXnnOperatorType == XnnOpType::max_pooling2d_nhwc_f32) {
-            std::vector<size_t> inputDims = mInputs[0]->dims;
-            if (!inputBuffers[0] || !outputBuffers[0]) {
-                return DAWN_INTERNAL_ERROR("Invalid parameters.");
-            }
-            const float* input = reinterpret_cast<const float*>(inputBuffers[0]);
-            float* output = reinterpret_cast<float*>(outputBuffers[0]);
-            size_t batchSize = inputDims[0];
-            size_t inputHeight = inputDims[1];
-            size_t inputWidth = inputDims[2];
-            if (mXnnOperatorType == XnnOpType::convolution2d_nhwc_f32) {
-                DAWN_TRY(xnn_setup_convolution2d_nhwc_f32(mXnnOperator, batchSize, inputHeight,
-                                                          inputWidth, input, output,
-                                                          GetThreadpool()));
-            } else if (mXnnOperatorType == XnnOpType::average_pooling2d_nhwc_f32) {
-                DAWN_TRY(xnn_setup_average_pooling2d_nhwc_f32(mXnnOperator, batchSize, inputHeight,
-                                                              inputWidth, input, output,
-                                                              GetThreadpool()));
-            } else if (mXnnOperatorType == XnnOpType::max_pooling2d_nhwc_f32) {
-                DAWN_TRY(xnn_setup_max_pooling2d_nhwc_f32(mXnnOperator, batchSize, inputHeight,
-                                                          inputWidth, input, output,
-                                                          GetThreadpool()));
-            }
-        } else if (mXnnOperatorType == XnnOpType::clamp_nc_f32) {
-            const std::shared_ptr<OperandInfo>& outputInfo = mOutputs[0];
diff --git a/src/webnn_native/xnnpack/GraphXNN.h b/src/webnn_native/xnnpack/GraphXNN.h
index fae674c26..632b1a2df 100644
--- a/src/webnn_native/xnnpack/GraphXNN.h
+++ b/src/webnn_native/xnnpack/GraphXNN.h
@@ -24,12 +24,17 @@
 #include "webnn_native/Operand.h"
 #include "webnn_native/ops/Binary.h"
 #include "webnn_native/ops/Clamp.h"
+#include "webnn_native/ops/Concat.h"
 #include "webnn_native/ops/Constant.h"
 #include "webnn_native/ops/Conv2d.h"
+#include "webnn_native/ops/Gemm.h"
 #include "webnn_native/ops/Input.h"
 #include "webnn_native/ops/LeakyRelu.h"
+#include "webnn_native/ops/Pad.h"
 #include "webnn_native/ops/Pool2d.h"
 #include "webnn_native/ops/Reshape.h"
+#include "webnn_native/ops/Split.h"
+#include "webnn_native/ops/Squeeze.h"
 #include "webnn_native/ops/Transpose.h"
 #include "webnn_native/ops/Unary.h"
 #include "webnn_native/xnnpack/ContextXNN.h"
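Note: since node definition is deferred to Finish(), each Add* override declared below presumably reduces to recording an OperatorInfo entry (the OperatorType and OperatorInfo types appear at the end of this header). The bodies are not part of this patch excerpt; a hypothetical sketch of the pattern:

    // Hypothetical shape of the new Add* overrides; the real bodies are not shown here.
    MaybeError Graph::AddUnary(const op::Unary* unary) {
        // Record the operator; the XNNPACK node is defined later in Finish().
        mOperators.push_back(OperatorInfo(OperatorType::Unary, unary));
        return {};
    }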
@@ -45,9 +50,15 @@ namespace webnn_native::xnnpack {
         virtual MaybeError AddInput(const op::Input* input) override;
         virtual MaybeError AddOutput(std::string_view name, const OperandBase* output) override;
         virtual MaybeError AddBinary(const op::Binary* binary) override;
-        virtual MaybeError AddClamp(const op::Clamp* clamp) override;
+        virtual MaybeError AddConcat(const op::Concat* concat) override;
         virtual MaybeError AddConv2d(const op::Conv2d* conv2d) override;
+        virtual MaybeError AddClamp(const op::Clamp* clamp) override;
+        virtual MaybeError AddGemm(const op::Gemm* gemm) override;
+        virtual MaybeError AddPad(const op::Pad* pad) override;
         virtual MaybeError AddPool2d(const op::Pool2d* pool2d) override;
+        virtual MaybeError AddReshape(const op::Reshape* reshape) override;
+        virtual MaybeError AddSplit(const op::Split* split) override;
+        virtual MaybeError AddSqueeze(const op::Squeeze* squeeze) override;
         virtual MaybeError AddUnary(const op::Unary* unary) override;
         virtual MaybeError Finish() override;
 
@@ -55,50 +66,57 @@ namespace webnn_native::xnnpack {
         MaybeError CompileImpl() override;
         MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override;
 
-        enum OperandType { INPUT, CONSTANT, BINARY, CLAMP, CONV2D, POOL2D, UNARY };
-        struct OperandInfo {
-            OperandInfo(OperandType opType) : opType(opType) {
-            }
-            OperandType opType;
-            std::string name = "";
-            xnn_datatype dataType = xnn_datatype_invalid;
-            std::vector<size_t> dims = {};
-            std::unique_ptr<char> buffer = nullptr;
-        };
-
         pthreadpool_t GetThreadpool();
-        size_t SizeOfOperandInfo(const std::shared_ptr<OperandInfo>& info);
-        xnn_status CreateBuffer(std::shared_ptr<OperandInfo>& info,
-                                const void* data = nullptr,
-                                size_t length = 0);
-        xnn_status CreateXnnOp(const op::Unary* unary);
-        xnn_status CreateXnnOp(const op::Clamp* clamp);
-        xnn_status CreateXnnOp(const op::Binary* binary);
-        xnn_status CreateXnnOp(const op::Pool2d* pool2d);
-        xnn_status CreateXnnOp(const op::Conv2d* conv2d,
-                               const op::Binary* add = nullptr,
-                               const op::Clamp* clamp = nullptr);
-        enum XnnOpType {
-            add_nd_f32,
-            clamp_nc_f32,
-            multiply_nd_f32,
-            subtract_nd_f32,
-            convolution2d_nhwc_f32,
-            average_pooling2d_nhwc_f32,
-            max_pooling2d_nhwc_f32
+        xnn_status DefineXnnTensorValue(xnn_subgraph_t subgraph,
+                                        const OperandBase* operand,
+                                        uint32_t* id,
+                                        const void* data = nullptr);
+        xnn_status DefineXnnNode(xnn_subgraph_t subgraph, const op::Constant* constant);
+        xnn_status DefineXnnNode(xnn_subgraph_t subgraph, const op::Input* Input);
+        xnn_status DefineXnnNode(xnn_subgraph_t subgraph, const op::Binary* binary);
+        xnn_status DefineXnnNode(xnn_subgraph_t subgraph, const op::Clamp* clamp);
+        xnn_status DefineXnnNode(xnn_subgraph_t subgraph, const op::Concat* concat);
+        xnn_status DefineXnnNode(xnn_subgraph_t subgraph, const op::Conv2d* conv2d);
+        xnn_status DefineXnnNode(xnn_subgraph_t subgraph, const op::Gemm* gemm);
+        xnn_status DefineXnnNode(xnn_subgraph_t subgraph, const op::Pad* pad);
+        xnn_status DefineXnnNode(xnn_subgraph_t subgraph, const op::Pool2d* pool2d);
+        xnn_status DefineXnnNode(xnn_subgraph_t subgraph, const op::Reshape* reshape);
+        xnn_status DefineXnnNode(xnn_subgraph_t subgraph, const op::Split* split);
+        xnn_status DefineXnnNode(xnn_subgraph_t subgraph, const op::Squeeze* squeeze);
+        xnn_status DefineXnnNode(xnn_subgraph_t subgraph, const op::Unary* unary);
+
+        enum OperatorType {
+            Binary,
+            Constant,
+            Clamp,
+            Concat,
+            Conv2d,
+            Input,
+            Gemm,
+            Pad,
+            Pool2d,
+            Reshape,
+            Split,
+            Squeeze,
+            Unary
+        };
+        struct OperatorInfo {
+            OperatorInfo(OperatorType type, const OperatorBase* op) : type(type), op(op) {
+            }
+            OperatorType type;
+            const OperatorBase* op;
         };
-        XnnOpType mXnnOperatorType;
-        xnn_operator_t mXnnOperator;
-        std::vector<std::shared_ptr<OperandInfo>> mConstants;
-        std::vector<std::shared_ptr<OperandInfo>> mInputs;
-        std::vector<std::shared_ptr<OperandInfo>> mOutputs;
-        std::map<std::string, size_t> mExternalInputs;
-        std::map<std::string, size_t> mExternalOutputs;
+        std::vector<OperatorInfo> mOperators;
+        std::map<const OperandBase*, uint32_t> mOperands;
+        std::map<std::string, uint32_t> mInputs;
+        std::map<std::string, uint32_t> mOutputs;
+        uint32_t mExternalId;
+
+        std::vector<std::unique_ptr<char[]>> mBuffers;
+        std::map<std::string, uint32_t> mExternals;
 
-        // For graph building
-        std::vector<const OperandBase*> mOperandsToBuild;
-        std::map<const OperandBase*, std::shared_ptr<OperandInfo>> mOperandInfoMap;
+        xnn_runtime_t mRuntime;
     };
 
 }  // namespace webnn_native::xnnpack
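Note: for reference, HANDLE_OP(Unary) inside the switch in Graph::Finish() expands to the case below; the cast target is reconstructed from the DefineXnnNode overload set declared in this header.

    // Expansion of HANDLE_OP(Unary) after preprocessing (illustrative).
    case OperatorType::Unary: {
        DAWN_TRY(DefineXnnNode(subgraph, reinterpret_cast<const op::Unary*>(info.op)));
        break;
    }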
diff --git a/webnn.json b/webnn.json
index a17a8ecd8..8c79f3f4d 100644
--- a/webnn.json
+++ b/webnn.json
@@ -73,7 +73,8 @@
             {"value": 1, "name": "DirectML"},
             {"value": 2, "name": "OpenVINO"},
             {"value": 3, "name": "OneDNN"},
-            {"value": 4, "name": "MLAS"}
+            {"value": 4, "name": "MLAS"},
+            {"value": 5, "name": "XNNPACK"}
         ]
     },
     "error type": {
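Note: following the dawn.json-style code generation used by webnn-native, this new entry should surface as an enum constant in both generated API surfaces. The exact spellings below are assumed from the existing values (DirectML, OpenVINO, OneDNN, MLAS) rather than taken from this patch:

    // Assumed generated names (hypothetical):
    WNNBackendType backendC = WNNBackendType_XNNPACK;          // C header
    wnn::BackendType backendCpp = wnn::BackendType::XNNPACK;   // C++ header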