From 75617dbffdc6cfe495213673977a0e55e5d1e95a Mon Sep 17 00:00:00 2001 From: lishaguo Date: Fri, 8 Apr 2022 10:03:43 -0400 Subject: [PATCH] [Base] make ComputeImpl return MaybeError --- examples/LeNet/Main.cpp | 4 +- examples/MobileNetV2/Main.cpp | 10 ++-- examples/ResNet/Main.cpp | 10 ++-- examples/SampleUtils.cpp | 6 +-- examples/SampleUtils.h | 19 +++---- examples/SqueezeNet/Main.cpp | 10 ++-- node/src/Graph.cpp | 4 +- src/tests/unittests/native/GraphMockTests.cpp | 6 --- src/tests/unittests/native/mocks/GraphMock.h | 2 +- src/webnn_native/Graph.cpp | 24 ++++----- src/webnn_native/Graph.h | 5 +- src/webnn_native/dml/GraphDML.cpp | 10 ++-- src/webnn_native/dml/GraphDML.h | 3 +- src/webnn_native/mlas/GraphMLAS.cpp | 10 ++-- src/webnn_native/mlas/GraphMLAS.h | 3 +- src/webnn_native/null/ContextNull.cpp | 4 +- src/webnn_native/null/ContextNull.h | 3 +- src/webnn_native/onednn/GraphDNNL.cpp | 15 +++--- src/webnn_native/onednn/GraphDNNL.h | 3 +- src/webnn_native/openvino/GraphIE.cpp | 25 ++++----- src/webnn_native/openvino/GraphIE.h | 3 +- src/webnn_native/xnnpack/GraphXNN.cpp | 52 +++++++++---------- src/webnn_native/xnnpack/GraphXNN.h | 3 +- src/webnn_wire/client/ClientDoers.cpp | 4 +- src/webnn_wire/client/Graph.cpp | 10 ++-- src/webnn_wire/client/Graph.h | 6 +-- src/webnn_wire/server/Server.h | 2 +- src/webnn_wire/server/ServerGraph.cpp | 8 +-- webnn.json | 13 +---- webnn_wire.json | 8 +-- 30 files changed, 114 insertions(+), 171 deletions(-) diff --git a/examples/LeNet/Main.cpp b/examples/LeNet/Main.cpp index ef29c03cc..5e5f69129 100644 --- a/examples/LeNet/Main.cpp +++ b/examples/LeNet/Main.cpp @@ -87,9 +87,7 @@ int main(int argc, const char* argv[]) { for (int i = 0; i < nIter; ++i) { std::chrono::time_point executionStartTime = std::chrono::high_resolution_clock::now(); - wnn::ComputeGraphStatus status = - utils::Compute(graph, {{"input", input}}, {{"output", result}}); - DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success); + utils::Compute(graph, {{"input", input}}, {{"output", result}}); executionTimeVector.push_back(std::chrono::high_resolution_clock::now() - executionStartTime); } diff --git a/examples/MobileNetV2/Main.cpp b/examples/MobileNetV2/Main.cpp index 3b0dc8c0d..09fc0871f 100644 --- a/examples/MobileNetV2/Main.cpp +++ b/examples/MobileNetV2/Main.cpp @@ -60,18 +60,14 @@ int main(int argc, const char* argv[]) { std::vector result(utils::SizeOfShape(mobilevetv2.mOutputShape)); // Do the first inference for warming up if nIter > 1. if (mobilevetv2.mNIter > 1) { - wnn::ComputeGraphStatus status = - utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); - DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success); + utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); } std::vector executionTime; for (int i = 0; i < mobilevetv2.mNIter; ++i) { std::chrono::time_point executionStartTime = std::chrono::high_resolution_clock::now(); - wnn::ComputeGraphStatus status = - utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); - DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success); + utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime); } @@ -79,4 +75,4 @@ int main(int argc, const char* argv[]) { utils::PrintExexutionTime(executionTime); utils::PrintResult(result, mobilevetv2.mLabelPath); dawn::InfoLog() << "Done."; -} \ No newline at end of file +} diff --git a/examples/ResNet/Main.cpp b/examples/ResNet/Main.cpp index c88ca276e..5b29f3679 100644 --- a/examples/ResNet/Main.cpp +++ b/examples/ResNet/Main.cpp @@ -60,18 +60,14 @@ int main(int argc, const char* argv[]) { std::vector result(utils::SizeOfShape(resnet.mOutputShape)); // Do the first inference for warming up if nIter > 1. if (resnet.mNIter > 1) { - wnn::ComputeGraphStatus status = - utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); - DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success); + utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); } std::vector executionTime; for (int i = 0; i < resnet.mNIter; ++i) { std::chrono::time_point executionStartTime = std::chrono::high_resolution_clock::now(); - wnn::ComputeGraphStatus status = - utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); - DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success); + utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime); } @@ -79,4 +75,4 @@ int main(int argc, const char* argv[]) { utils::PrintExexutionTime(executionTime); utils::PrintResult(result, resnet.mLabelPath); dawn::InfoLog() << "Done."; -} \ No newline at end of file +} diff --git a/examples/SampleUtils.cpp b/examples/SampleUtils.cpp index 315bad342..043ed6d48 100644 --- a/examples/SampleUtils.cpp +++ b/examples/SampleUtils.cpp @@ -282,9 +282,9 @@ namespace utils { return builder.Build(namedOperands); } - wnn::ComputeGraphStatus Compute(const wnn::Graph& graph, - const std::vector>& inputs, - const std::vector>& outputs) { + void Compute(const wnn::Graph& graph, + const std::vector>& inputs, + const std::vector>& outputs) { return Compute(graph, inputs, outputs); } diff --git a/examples/SampleUtils.h b/examples/SampleUtils.h index 05b3a40e6..af8d8e016 100644 --- a/examples/SampleUtils.h +++ b/examples/SampleUtils.h @@ -243,12 +243,11 @@ namespace utils { }; template - wnn::ComputeGraphStatus Compute(const wnn::Graph& graph, - const std::vector>& inputs, - const std::vector>& outputs) { + void Compute(const wnn::Graph& graph, + const std::vector>& inputs, + const std::vector>& outputs) { if (graph.GetHandle() == nullptr) { dawn::ErrorLog() << "The graph is invaild."; - return wnn::ComputeGraphStatus::Error; } // The `mlInputs` local variable to hold the input data util computing the graph. @@ -274,15 +273,13 @@ namespace utils { mlOutputs.push_back(resource); namedOutputs.Set(output.name.c_str(), &mlOutputs.back()); } - wnn::ComputeGraphStatus status = graph.Compute(namedInputs, namedOutputs); + graph.Compute(namedInputs, namedOutputs); DoFlush(); - - return status; } - wnn::ComputeGraphStatus Compute(const wnn::Graph& graph, - const std::vector>& inputs, - const std::vector>& outputs); + void Compute(const wnn::Graph& graph, + const std::vector>& inputs, + const std::vector>& outputs); template bool CheckValue(const std::vector& value, const std::vector& expectedValue) { @@ -335,4 +332,4 @@ namespace utils { const std::string& powerPreference = "default"); } // namespace utils -#endif // WEBNN_NATIVE_EXAMPLES_SAMPLE_UTILS_H_ \ No newline at end of file +#endif // WEBNN_NATIVE_EXAMPLES_SAMPLE_UTILS_H_ diff --git a/examples/SqueezeNet/Main.cpp b/examples/SqueezeNet/Main.cpp index 5b2d58260..6c03b56d6 100644 --- a/examples/SqueezeNet/Main.cpp +++ b/examples/SqueezeNet/Main.cpp @@ -60,18 +60,14 @@ int main(int argc, const char* argv[]) { std::vector result(utils::SizeOfShape(squeezenet.mOutputShape)); // Do the first inference for warming up if nIter > 1. if (squeezenet.mNIter > 1) { - wnn::ComputeGraphStatus status = - utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); - DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success); + utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); } std::vector executionTime; for (int i = 0; i < squeezenet.mNIter; ++i) { std::chrono::time_point executionStartTime = std::chrono::high_resolution_clock::now(); - wnn::ComputeGraphStatus status = - utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); - DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success); + utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime); } @@ -79,4 +75,4 @@ int main(int argc, const char* argv[]) { utils::PrintExexutionTime(executionTime); utils::PrintResult(result, squeezenet.mLabelPath); dawn::InfoLog() << "Done."; -} \ No newline at end of file +} diff --git a/node/src/Graph.cpp b/node/src/Graph.cpp index 588ba1efb..af4f09003 100644 --- a/node/src/Graph.cpp +++ b/node/src/Graph.cpp @@ -133,9 +133,9 @@ namespace node { for (auto& output : outputs) { namedOutputs.Set(output.first.data(), &output.second); } - wnn::ComputeGraphStatus status = mImpl.Compute(namedInputs, namedOutputs); + mImpl.Compute(namedInputs, namedOutputs); - return Napi::Number::New(info.Env(), static_cast(status)); + return Napi::Number::New(info.Env(), 0); } Napi::Object Graph::Initialize(Napi::Env env, Napi::Object exports) { diff --git a/src/tests/unittests/native/GraphMockTests.cpp b/src/tests/unittests/native/GraphMockTests.cpp index 5d10b574e..cab725da3 100644 --- a/src/tests/unittests/native/GraphMockTests.cpp +++ b/src/tests/unittests/native/GraphMockTests.cpp @@ -37,11 +37,5 @@ namespace webnn_native { namespace { EXPECT_TRUE(graphMock.Compile().IsSuccess()); } - TEST_F(GraphMockTests, Compute) { - EXPECT_CALL(graphMock, ComputeImpl).Times(1); - NamedInputsBase inputs; - NamedOutputsBase outputs; - EXPECT_TRUE(graphMock.Compute(&inputs, &outputs) == WNNComputeGraphStatus_Success); - } }} // namespace webnn_native:: diff --git a/src/tests/unittests/native/mocks/GraphMock.h b/src/tests/unittests/native/mocks/GraphMock.h index fdc76cca3..6e0e34504 100644 --- a/src/tests/unittests/native/mocks/GraphMock.h +++ b/src/tests/unittests/native/mocks/GraphMock.h @@ -56,7 +56,7 @@ namespace webnn_native { (override)); MOCK_METHOD(MaybeError, Finish, (), (override)); MOCK_METHOD(MaybeError, CompileImpl, (), (override)); - MOCK_METHOD(WNNComputeGraphStatus, + MOCK_METHOD(MaybeError, ComputeImpl, (NamedInputsBase * inputs, NamedOutputsBase* outputs), (override)); diff --git a/src/webnn_native/Graph.cpp b/src/webnn_native/Graph.cpp index fb2f1cc1c..a725ea9c4 100644 --- a/src/webnn_native/Graph.cpp +++ b/src/webnn_native/Graph.cpp @@ -32,9 +32,8 @@ namespace webnn_native { MaybeError CompileImpl() override { UNREACHABLE(); } - WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs, - NamedOutputsBase* outputs) override { - return WNNComputeGraphStatus_Error; + MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override { + return DAWN_INTERNAL_ERROR("fail to build graph!"); } }; } // namespace @@ -138,12 +137,14 @@ namespace webnn_native { return CompileImpl(); } - WNNComputeGraphStatus GraphBase::Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs) { + void GraphBase::Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs) { if (inputs == nullptr || outputs == nullptr) { - return WNNComputeGraphStatus_Error; + GetContext()->ConsumedError(DAWN_VALIDATION_ERROR("input or output is nullptr.")); } - return ComputeImpl(inputs, outputs); + if (GetContext()->ConsumedError(ComputeImpl(inputs, outputs))) { + dawn::ErrorLog() << "fail to compute graph"; + } } void GraphBase::ComputeAsync(NamedInputsBase* inputs, @@ -151,14 +152,11 @@ namespace webnn_native { WNNComputeAsyncCallback callback, void* userdata) { if (inputs == nullptr || outputs == nullptr) { - callback(WNNComputeGraphStatus_Error, "named inputs or outputs is empty.", userdata); + callback(WNNErrorType_Validation, "named inputs or outputs is empty.", userdata); } - // TODO: Get error message from implemenation, ComputeImpl should return MaybeError, - // which is tracked with issues-959. - WNNComputeGraphStatus status = ComputeImpl(inputs, outputs); - std::string messages = status != WNNComputeGraphStatus_Success ? "Failed to async compute" - : "Success async compute"; - callback(status, messages.c_str(), userdata); + std::unique_ptr errorData = ComputeImpl(inputs, outputs).AcquireError(); + callback(static_cast(ToWNNErrorType(errorData->GetType())), + const_cast(errorData->GetMessage().c_str()), userdata); } GraphBase::GraphBase(ContextBase* context, ObjectBase::ErrorTag tag) diff --git a/src/webnn_native/Graph.h b/src/webnn_native/Graph.h index adc259318..7aa444209 100644 --- a/src/webnn_native/Graph.h +++ b/src/webnn_native/Graph.h @@ -82,7 +82,7 @@ namespace webnn_native { virtual MaybeError Compile(); // Webnn API - WNNComputeGraphStatus Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs); + void Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs); void ComputeAsync(NamedInputsBase* inputs, NamedOutputsBase* outputs, WNNComputeAsyncCallback callback, @@ -93,8 +93,7 @@ namespace webnn_native { private: virtual MaybeError CompileImpl() = 0; - virtual WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs, - NamedOutputsBase* outputs) = 0; + virtual MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) = 0; }; } // namespace webnn_native diff --git a/src/webnn_native/dml/GraphDML.cpp b/src/webnn_native/dml/GraphDML.cpp index b443f05d3..d35369558 100644 --- a/src/webnn_native/dml/GraphDML.cpp +++ b/src/webnn_native/dml/GraphDML.cpp @@ -1814,13 +1814,12 @@ namespace webnn_native::dml { return {}; } - WNNComputeGraphStatus Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) { + MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) { auto namedInputs = inputs->GetRecords(); for (auto& [name, inputBinding] : mInputBindingMap) { // All the inputs must be set. if (namedInputs.find(name) == namedInputs.end()) { - dawn::ErrorLog() << "The input must be set."; - return WNNComputeGraphStatus_Error; + return DAWN_INTERNAL_ERROR("All inputs must be set."); } auto& resource = namedInputs[name].resource; @@ -1882,8 +1881,7 @@ namespace webnn_native::dml { std::vector outputTensors; if (FAILED(mDevice->DispatchOperator(mCompiledModel->op.Get(), inputBindings, outputExpressions, outputTensors))) { - dawn::ErrorLog() << "Failed to dispatch operator."; - return WNNComputeGraphStatus_Error; + return DAWN_INTERNAL_ERROR("Fail to dispatch operator."); } for (size_t i = 0; i < outputNames.size(); ++i) { @@ -1903,7 +1901,7 @@ namespace webnn_native::dml { delete tensor; } #endif - return WNNComputeGraphStatus_Success; + return {}; } } // namespace webnn_native::dml diff --git a/src/webnn_native/dml/GraphDML.h b/src/webnn_native/dml/GraphDML.h index ba5cfbaba..d6fcc2a75 100644 --- a/src/webnn_native/dml/GraphDML.h +++ b/src/webnn_native/dml/GraphDML.h @@ -100,8 +100,7 @@ namespace webnn_native::dml { private: MaybeError CompileImpl() override; - WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs, - NamedOutputsBase* outputs) override; + MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; ::dml::Expression BindingConstant(DML_TENSOR_DATA_TYPE dmlTensorType, ::dml::TensorDimensions dmlTensorDims, diff --git a/src/webnn_native/mlas/GraphMLAS.cpp b/src/webnn_native/mlas/GraphMLAS.cpp index ad3fad36c..09c7ffa6a 100644 --- a/src/webnn_native/mlas/GraphMLAS.cpp +++ b/src/webnn_native/mlas/GraphMLAS.cpp @@ -978,13 +978,12 @@ namespace webnn_native::mlas { return {}; } - WNNComputeGraphStatus Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) { + MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) { for (auto& [name, input] : inputs->GetRecords()) { Ref inputMemory = mInputs.at(name); auto& resource = input.resource.arrayBufferView; if (inputMemory->GetByteLength() < resource.byteLength) { - dawn::ErrorLog() << "The size of input memory is less than input buffer."; - return WNNComputeGraphStatus_Error; + return DAWN_INTERNAL_ERROR("The size of input memory is less than input buffer."); } memcpy(inputMemory->GetBuffer(), static_cast(resource.buffer) + resource.byteOffset, @@ -1005,13 +1004,12 @@ namespace webnn_native::mlas { Ref outputMemory = mOutputs.at(outputName); const ArrayBufferView& output = outputs->GetRecords().at(outputName).arrayBufferView; if (output.byteLength < outputMemory->GetByteLength()) { - dawn::ErrorLog() << "The size of output buffer is less than output memory."; - return WNNComputeGraphStatus_Error; + return DAWN_INTERNAL_ERROR("The size of output buffer is less than output memory."); } memcpy(static_cast(output.buffer) + output.byteOffset, outputMemory->GetBuffer(), output.byteLength); } - return WNNComputeGraphStatus_Success; + return {}; } } // namespace webnn_native::mlas diff --git a/src/webnn_native/mlas/GraphMLAS.h b/src/webnn_native/mlas/GraphMLAS.h index 80dd588a1..36a64ff98 100644 --- a/src/webnn_native/mlas/GraphMLAS.h +++ b/src/webnn_native/mlas/GraphMLAS.h @@ -55,8 +55,7 @@ namespace webnn_native::mlas { private: MaybeError CompileImpl() override; - WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs, - NamedOutputsBase* outputs) override; + MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; std::unordered_map> mInputs; std::unordered_map> mOutputs; diff --git a/src/webnn_native/null/ContextNull.cpp b/src/webnn_native/null/ContextNull.cpp index e888e2672..bda183c08 100644 --- a/src/webnn_native/null/ContextNull.cpp +++ b/src/webnn_native/null/ContextNull.cpp @@ -68,8 +68,8 @@ namespace webnn_native::null { return {}; } - WNNComputeGraphStatus Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) { - return WNNComputeGraphStatus_Success; + MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) { + return {}; } MaybeError Graph::AddConstant(const op::Constant* constant) { diff --git a/src/webnn_native/null/ContextNull.h b/src/webnn_native/null/ContextNull.h index 768bd17e6..a175f691c 100644 --- a/src/webnn_native/null/ContextNull.h +++ b/src/webnn_native/null/ContextNull.h @@ -76,8 +76,7 @@ namespace webnn_native::null { private: MaybeError CompileImpl() override; - WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs, - NamedOutputsBase* outputs) override; + MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; }; } // namespace webnn_native::null diff --git a/src/webnn_native/onednn/GraphDNNL.cpp b/src/webnn_native/onednn/GraphDNNL.cpp index c79789389..93e9fa95f 100644 --- a/src/webnn_native/onednn/GraphDNNL.cpp +++ b/src/webnn_native/onednn/GraphDNNL.cpp @@ -989,20 +989,19 @@ namespace webnn_native::onednn { return {}; } - WNNComputeGraphStatus Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) { + MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) { for (auto& [name, input] : inputs->GetRecords()) { dnnl_memory_t inputMemory = mInputMemoryMap.at(name); auto& resource = input.resource.arrayBufferView; - COMPUTE_TRY(dnnl_memory_set_data_handle_v2( + DAWN_TRY(dnnl_memory_set_data_handle_v2( inputMemory, static_cast(resource.buffer) + resource.byteOffset, mStream)); } for (auto op : mOperations) { - COMPUTE_TRY( - dnnl_primitive_execute(op.primitive, mStream, op.args.size(), op.args.data())); + DAWN_TRY(dnnl_primitive_execute(op.primitive, mStream, op.args.size(), op.args.data())); } - COMPUTE_TRY(dnnl_stream_wait(mStream)); + DAWN_TRY(dnnl_stream_wait(mStream)); std::vector outputNames; for (auto& [name, _] : outputs->GetRecords()) { @@ -1013,17 +1012,17 @@ namespace webnn_native::onednn { std::string outputName = outputNames[i]; dnnl_memory_t outputMemory = mOutputMemoryMap.at(outputName); const dnnl_memory_desc_t* outputMemoryDesc; - COMPUTE_TRY(GetMemoryDesc(outputMemory, &outputMemoryDesc)); + DAWN_TRY(GetMemoryDesc(outputMemory, &outputMemoryDesc)); size_t bufferLength = dnnl_memory_desc_get_size(outputMemoryDesc); void* outputBuffer = malloc(bufferLength); - COMPUTE_TRY(ReadFromMemory(outputBuffer, bufferLength, outputMemory)); + DAWN_TRY(ReadFromMemory(outputBuffer, bufferLength, outputMemory)); ArrayBufferView output = outputs->GetRecords().at(outputName).arrayBufferView; if (output.byteLength >= bufferLength) { memcpy(static_cast(output.buffer) + output.byteOffset, outputBuffer, bufferLength); } } - return WNNComputeGraphStatus_Success; + return {}; } dnnl_engine_t Graph::GetEngine() { diff --git a/src/webnn_native/onednn/GraphDNNL.h b/src/webnn_native/onednn/GraphDNNL.h index 4c1830324..deca7eb41 100644 --- a/src/webnn_native/onednn/GraphDNNL.h +++ b/src/webnn_native/onednn/GraphDNNL.h @@ -62,8 +62,7 @@ namespace webnn_native::onednn { dnnl_status_t BuildPrimitives(); MaybeError CompileImpl() override; - WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs, - NamedOutputsBase* outputs) override; + MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; dnnl_engine_t GetEngine(); dnnl_status_t GetMemoryDesc(dnnl_memory_t memory, const dnnl_memory_desc_t** desc); dnnl_status_t ReorderIfNeeded(const dnnl_memory_desc_t* srcDesc, diff --git a/src/webnn_native/openvino/GraphIE.cpp b/src/webnn_native/openvino/GraphIE.cpp index d4d09e896..8726dfe99 100644 --- a/src/webnn_native/openvino/GraphIE.cpp +++ b/src/webnn_native/openvino/GraphIE.cpp @@ -1313,31 +1313,27 @@ namespace webnn_native::ie { return {}; } - WNNComputeGraphStatus Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) { + MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) { auto namedInputs = inputs->GetRecords(); for (auto& [name, input] : mInputIdMap) { // All the inputs must be set. if (namedInputs.find(name) == namedInputs.end()) { - dawn::ErrorLog() << "The input isn't set"; - return WNNComputeGraphStatus_Error; + return DAWN_VALIDATION_ERROR("The input isn't set"); } ie_blob_t* blob; char* inputName = nullptr; IEStatusCode status = ie_network_get_input_name(mInferEngineNetwork, input, &inputName); if (status != IEStatusCode::OK) { - dawn::ErrorLog() << "IE Failed to ie_network_get_input_name"; - return WNNComputeGraphStatus_Error; + return DAWN_INTERNAL_ERROR("IE Failed to ie_network_get_input_name"); } status = ie_infer_request_get_blob(mInferEngineRequest, inputName, &blob); if (status != IEStatusCode::OK) { - dawn::ErrorLog() << "IE Failed to ie_infer_request_get_blob"; - return WNNComputeGraphStatus_Error; + return DAWN_INTERNAL_ERROR("IE Failed to ie_infer_request_get_blob"); } ie_blob_buffer_t buffer; status = ie_blob_get_buffer(blob, &buffer); if (status != IEStatusCode::OK) { - dawn::ErrorLog() << "IE Failed to ie_blob_get_buffer"; - return WNNComputeGraphStatus_Error; + return DAWN_INTERNAL_ERROR("IE Failed to ie_blob_get_buffer"); } auto& resource = namedInputs[name].resource.arrayBufferView; memcpy(buffer.buffer, static_cast(resource.buffer) + resource.byteOffset, @@ -1347,8 +1343,7 @@ namespace webnn_native::ie { // Compute the compiled model. IEStatusCode code = ie_infer_request_infer(mInferEngineRequest); if (code != IEStatusCode::OK) { - dawn::ErrorLog() << "IE Failed to compute model"; - return WNNComputeGraphStatus_Error; + return DAWN_INTERNAL_ERROR("IE Failed to compute model"); } // Get Data from nGraph with output. @@ -1358,8 +1353,7 @@ namespace webnn_native::ie { // Get output id with friendly name. auto originalName = mOutputNameMap[name]; if (mOriginalNameMap.find(originalName) == mOriginalNameMap.end()) { - dawn::ErrorLog() << "IE Failed to compute model"; - return WNNComputeGraphStatus_Error; + return DAWN_INTERNAL_ERROR("IE Failed to get output"); } char* sinkingName; IEStatusCode status = ie_network_get_output_name( @@ -1367,8 +1361,7 @@ namespace webnn_native::ie { ie_blob_t* outputBlob; status = ie_infer_request_get_blob(mInferEngineRequest, sinkingName, &outputBlob); if (status != IEStatusCode::OK) { - dawn::ErrorLog() << "IE Failed to ie_infer_request_get_blob"; - return WNNComputeGraphStatus_Error; + return DAWN_INTERNAL_ERROR("IE Failed to ie_infer_request_get_blob"); } ie_blob_buffer_t outputBuffer; status = ie_blob_get_cbuffer(outputBlob, &outputBuffer); @@ -1380,6 +1373,6 @@ namespace webnn_native::ie { } } - return WNNComputeGraphStatus_Success; + return {}; } } // namespace webnn_native::ie diff --git a/src/webnn_native/openvino/GraphIE.h b/src/webnn_native/openvino/GraphIE.h index c7df54f7b..a8f24b87d 100644 --- a/src/webnn_native/openvino/GraphIE.h +++ b/src/webnn_native/openvino/GraphIE.h @@ -80,8 +80,7 @@ namespace webnn_native::ie { private: MaybeError CompileImpl() override; - WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs, - NamedOutputsBase* outputs) override; + MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; // Map the input name to IE internal input number. std::map mInputIdMap; diff --git a/src/webnn_native/xnnpack/GraphXNN.cpp b/src/webnn_native/xnnpack/GraphXNN.cpp index 80cf2efd1..747536bd6 100644 --- a/src/webnn_native/xnnpack/GraphXNN.cpp +++ b/src/webnn_native/xnnpack/GraphXNN.cpp @@ -701,7 +701,7 @@ namespace webnn_native::xnnpack { } for (auto& [name, input] : inputs->GetRecords()) { if (mExternalInputs.find(name) == mExternalInputs.end()) { - COMPUTE_ERROR("Invalid parameters."); + return DAWN_INTERNAL_ERROR("Invalid parameters."); } size_t index = mExternalInputs.at(name); auto& resource = input.resource.arrayBufferView; @@ -734,7 +734,7 @@ namespace webnn_native::xnnpack { mXnnOperatorType == XnnOpType::max_pooling2d_nhwc_f32) { std::vector inputDims = mInputs[0]->dims; if (!inputBuffers[0] || !outputBuffers[0]) { - COMPUTE_ERROR("Invalid parameters."); + return DAWN_INTERNAL_ERROR("Invalid parameters."); } const float* input = reinterpret_cast(inputBuffers[0]); float* output = reinterpret_cast(outputBuffers[0]); @@ -742,28 +742,28 @@ namespace webnn_native::xnnpack { size_t inputHeight = inputDims[1]; size_t inputWidth = inputDims[2]; if (mXnnOperatorType == XnnOpType::convolution2d_nhwc_f32) { - COMPUTE_TRY(xnn_setup_convolution2d_nhwc_f32(mXnnOperator, batchSize, inputHeight, - inputWidth, input, output, - GetThreadpool())); + DAWN_TRY(xnn_setup_convolution2d_nhwc_f32(mXnnOperator, batchSize, inputHeight, + inputWidth, input, output, + GetThreadpool())); } else if (mXnnOperatorType == XnnOpType::average_pooling2d_nhwc_f32) { - COMPUTE_TRY(xnn_setup_average_pooling2d_nhwc_f32(mXnnOperator, batchSize, - inputHeight, inputWidth, input, - output, GetThreadpool())); + DAWN_TRY(xnn_setup_average_pooling2d_nhwc_f32(mXnnOperator, batchSize, inputHeight, + inputWidth, input, output, + GetThreadpool())); } else if (mXnnOperatorType == XnnOpType::max_pooling2d_nhwc_f32) { - COMPUTE_TRY(xnn_setup_max_pooling2d_nhwc_f32(mXnnOperator, batchSize, inputHeight, - inputWidth, input, output, - GetThreadpool())); + DAWN_TRY(xnn_setup_max_pooling2d_nhwc_f32(mXnnOperator, batchSize, inputHeight, + inputWidth, input, output, + GetThreadpool())); } } else if (mXnnOperatorType == XnnOpType::clamp_nc_f32) { const std::shared_ptr& outputInfo = mOutputs[0]; size_t batchSize = std::accumulate(outputInfo->dims.begin(), outputInfo->dims.end(), 1, std::multiplies()); if (!inputBuffers[0] || !outputBuffers[0]) { - COMPUTE_ERROR("Invalid parameters."); + return DAWN_INTERNAL_ERROR("Invalid parameters."); } const float* input = reinterpret_cast(inputBuffers[0]); float* output = reinterpret_cast(outputBuffers[0]); - COMPUTE_TRY( + DAWN_TRY( xnn_setup_clamp_nc_f32(mXnnOperator, batchSize, input, output, GetThreadpool())); } else if (mXnnOperatorType == XnnOpType::add_nd_f32 || mXnnOperatorType == XnnOpType::multiply_nd_f32 || @@ -771,31 +771,31 @@ namespace webnn_native::xnnpack { std::vector dims0 = mInputs[0]->dims; std::vector dims1 = mInputs[1]->dims; if (!inputBuffers[0] || !inputBuffers[1] || !outputBuffers[0]) { - COMPUTE_ERROR("Invalid parameters."); + return DAWN_INTERNAL_ERROR("Invalid parameters."); } const float* input0 = reinterpret_cast(inputBuffers[0]); const float* input1 = reinterpret_cast(inputBuffers[1]); float* output = reinterpret_cast(outputBuffers[0]); if (mXnnOperatorType == XnnOpType::add_nd_f32) { - COMPUTE_TRY(xnn_setup_add_nd_f32(mXnnOperator, dims0.size(), dims0.data(), - dims1.size(), dims1.data(), input0, input1, output, - GetThreadpool())); + DAWN_TRY(xnn_setup_add_nd_f32(mXnnOperator, dims0.size(), dims0.data(), + dims1.size(), dims1.data(), input0, input1, output, + GetThreadpool())); } else if (mXnnOperatorType == XnnOpType::multiply_nd_f32) { - COMPUTE_TRY(xnn_setup_multiply_nd_f32(mXnnOperator, dims0.size(), dims0.data(), - dims1.size(), dims1.data(), input0, input1, - output, GetThreadpool())); + DAWN_TRY(xnn_setup_multiply_nd_f32(mXnnOperator, dims0.size(), dims0.data(), + dims1.size(), dims1.data(), input0, input1, + output, GetThreadpool())); } else if (mXnnOperatorType == XnnOpType::subtract_nd_f32) { - COMPUTE_TRY(xnn_setup_subtract_nd_f32(mXnnOperator, dims0.size(), dims0.data(), - dims1.size(), dims1.data(), input0, input1, - output, GetThreadpool())); + DAWN_TRY(xnn_setup_subtract_nd_f32(mXnnOperator, dims0.size(), dims0.data(), + dims1.size(), dims1.data(), input0, input1, + output, GetThreadpool())); } } else { - COMPUTE_ERROR("The operator is not supported."); + return DAWN_INTERNAL_ERROR("The operator is not supported."); } - COMPUTE_TRY(xnn_run_operator(mXnnOperator, GetThreadpool())); + DAWN_TRY(xnn_run_operator(mXnnOperator, GetThreadpool())); - return WNNComputeGraphStatus_Success; + return {}; } } // namespace webnn_native::xnnpack diff --git a/src/webnn_native/xnnpack/GraphXNN.h b/src/webnn_native/xnnpack/GraphXNN.h index 552e86ac4..fae674c26 100644 --- a/src/webnn_native/xnnpack/GraphXNN.h +++ b/src/webnn_native/xnnpack/GraphXNN.h @@ -53,8 +53,7 @@ namespace webnn_native::xnnpack { private: MaybeError CompileImpl() override; - WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs, - NamedOutputsBase* outputs) override; + MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; enum OperandType { INPUT, CONSTANT, BINARY, CLAMP, CONV2D, POOL2D, UNARY }; struct OperandInfo { diff --git a/src/webnn_wire/client/ClientDoers.cpp b/src/webnn_wire/client/ClientDoers.cpp index b449154d0..d68d335b1 100644 --- a/src/webnn_wire/client/ClientDoers.cpp +++ b/src/webnn_wire/client/ClientDoers.cpp @@ -38,9 +38,9 @@ namespace webnn_wire::client { bool Client::DoGraphComputeAsyncCallback(Graph* graph, uint64_t requestSerial, - WNNComputeGraphStatus status, + WNNErrorType type, const char* message) { - return graph->OnComputeAsyncCallback(requestSerial, status, message); + return graph->OnComputeAsyncCallback(requestSerial, type, message); } } // namespace webnn_wire::client diff --git a/src/webnn_wire/client/Graph.cpp b/src/webnn_wire/client/Graph.cpp index 9b783ea73..8a45f75d8 100644 --- a/src/webnn_wire/client/Graph.cpp +++ b/src/webnn_wire/client/Graph.cpp @@ -20,7 +20,7 @@ namespace webnn_wire::client { - WNNComputeGraphStatus Graph::Compute(WNNNamedInputs inputs, WNNNamedOutputs outputs) { + void Graph::Compute(WNNNamedInputs inputs, WNNNamedOutputs outputs) { NamedInputs* namedInputs = FromAPI(inputs); NamedOutputs* namedOutputs = FromAPI(outputs); @@ -30,8 +30,6 @@ namespace webnn_wire::client { cmd.outputsId = namedOutputs->id; client->SerializeCommand(cmd); - - return WNNComputeGraphStatus::WNNComputeGraphStatus_Success; } void Graph::ComputeAsync(WNNNamedInputs inputs, @@ -39,7 +37,7 @@ namespace webnn_wire::client { WNNComputeAsyncCallback callback, void* userdata) { if (client->IsDisconnected()) { - callback(WNNComputeGraphStatus_ContextLost, "WebNN context disconnected", userdata); + callback(WNNErrorType_DeviceLost, "WebNN context disconnected", userdata); return; } @@ -61,7 +59,7 @@ namespace webnn_wire::client { } bool Graph::OnComputeAsyncCallback(uint64_t requestSerial, - WNNComputeGraphStatus status, + WNNErrorType type, const char* message) { auto requestIt = mComputeAsyncRequests.find(requestSerial); if (requestIt == mComputeAsyncRequests.end()) { @@ -71,7 +69,7 @@ namespace webnn_wire::client { ComputeAsyncRequest request = std::move(requestIt->second); mComputeAsyncRequests.erase(requestIt); - request.callback(status, message, request.userdata); + request.callback(type, message, request.userdata); return true; } diff --git a/src/webnn_wire/client/Graph.h b/src/webnn_wire/client/Graph.h index 0368d2303..bb57f936a 100644 --- a/src/webnn_wire/client/Graph.h +++ b/src/webnn_wire/client/Graph.h @@ -28,14 +28,12 @@ namespace webnn_wire::client { public: using ObjectBase::ObjectBase; - WNNComputeGraphStatus Compute(WNNNamedInputs inputs, WNNNamedOutputs outputs); + void Compute(WNNNamedInputs inputs, WNNNamedOutputs outputs); void ComputeAsync(WNNNamedInputs inputs, WNNNamedOutputs outputs, WNNComputeAsyncCallback callback, void* userdata); - bool OnComputeAsyncCallback(uint64_t requestSerial, - WNNComputeGraphStatus status, - const char* message); + bool OnComputeAsyncCallback(uint64_t requestSerial, WNNErrorType type, const char* message); private: struct ComputeAsyncRequest { diff --git a/src/webnn_wire/server/Server.h b/src/webnn_wire/server/Server.h index d124596d6..44aaa2700 100644 --- a/src/webnn_wire/server/Server.h +++ b/src/webnn_wire/server/Server.h @@ -159,7 +159,7 @@ namespace webnn_wire::server { WNNErrorType type, const char* message); void OnGraphComputeAsyncCallback(ComputeAsyncUserdata* userdata, - WNNComputeGraphStatus status, + WNNErrorType type, const char* message); #include "webnn_wire/server/ServerPrototypes_autogen.inc" diff --git a/src/webnn_wire/server/ServerGraph.cpp b/src/webnn_wire/server/ServerGraph.cpp index 8b237a7c0..a783f7272 100644 --- a/src/webnn_wire/server/ServerGraph.cpp +++ b/src/webnn_wire/server/ServerGraph.cpp @@ -66,10 +66,10 @@ namespace webnn_wire::server { return true; } - void Server::OnGraphComputeAsyncCallback(ComputeAsyncUserdata* userdata, - WNNComputeGraphStatus status, + void Server::OnGraphComputeAsyncCallback(ComputeAsyncUserdata* userdata, + WNNErrorType type, const char* message) { - if (status == WNNComputeGraphStatus_Success) { + if (type == WNNErrorType_NoError) { WNNArrayBufferView arrayBuffer = {}; auto* namedOutputs = NamedOutputsObjects().Get(userdata->namedOutputsObjectID); mProcs.namedOutputsGet(namedOutputs->handle, 0, &arrayBuffer); @@ -87,7 +87,7 @@ namespace webnn_wire::server { ReturnGraphComputeAsyncCallbackCmd cmd; cmd.graph = userdata->graph; cmd.requestSerial = userdata->requestSerial; - cmd.status = status; + cmd.type = type; cmd.message = message; SerializeCommand(cmd); diff --git a/webnn.json b/webnn.json index 4734d0475..a5a434010 100644 --- a/webnn.json +++ b/webnn.json @@ -1056,19 +1056,10 @@ } ] }, - "compute graph status": { - "category": "enum", - "values": [ - {"value": 0, "name": "success"}, - {"value": 1, "name": "error"}, - {"value": 2, "name": "context lost"}, - {"value": 3, "name": "unknown"} - ] - }, "compute async callback": { "category": "callback", "args": [ - {"name": "status", "type": "compute graph status"}, + {"name": "type", "type": "error type"}, {"name": "message", "type": "char", "annotation": "const*", "length": "strlen"}, {"name": "userdata", "type": "void", "annotation": "*"} ] @@ -1078,7 +1069,7 @@ "methods": [ { "name": "compute", - "returns": "compute graph status", + "returns": "void", "args": [ {"name": "inputs", "type": "named inputs"}, {"name": "outputs", "type": "named outputs"} diff --git a/webnn_wire.json b/webnn_wire.json index dc377fe99..5703a5912 100644 --- a/webnn_wire.json +++ b/webnn_wire.json @@ -115,7 +115,7 @@ "graph compute async callback": [ { "name": "graph", "type": "ObjectHandle", "handle_type": "graph" }, { "name": "request serial", "type": "uint64_t" }, - { "name": "status", "type": "compute graph status" }, + { "name": "type", "type": "error type"}, { "name": "message", "type": "char", "annotation": "const*", "length": "strlen" } ] }, @@ -138,10 +138,10 @@ "NamedOutputsGet", "OperandArraySize", "OperatorArraySize", - "GraphComputeAsync" + "GraphComputeAsync", + "GraphCompute" ], "client_handwritten_commands": [ - "GraphCompute", "ContextPushErrorScope" ], "client_special_objects": [ @@ -158,4 +158,4 @@ "server_handwritten_commands": [], "server_reverse_lookup_objects": [] } -} \ No newline at end of file +}