[Base] make ComputeImpl return MaybeError
lisa0314 committed Apr 12, 2022
1 parent a9ce6fc commit 6e1b690
Showing 30 changed files with 92 additions and 137 deletions.
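
In short, this commit moves the per-backend compute hook from a status-code style to Dawn's error-monad style. The signature change, as it appears in src/webnn_native/Graph.h below:

    // Before: backends reported success or failure through an enum.
    virtual WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs,
                                              NamedOutputsBase* outputs) = 0;

    // After: backends return MaybeError, so a failure carries a typed,
    // human-readable ErrorData that DAWN_TRY can propagate and the
    // context can consume.
    virtual MaybeError ComputeImpl(NamedInputsBase* inputs,
                                   NamedOutputsBase* outputs) = 0;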
4 changes: 1 addition & 3 deletions examples/LeNet/Main.cpp
@@ -87,9 +87,7 @@ int main(int argc, const char* argv[]) {
     for (int i = 0; i < nIter; ++i) {
         std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
             std::chrono::high_resolution_clock::now();
-        wnn::ComputeGraphStatus status =
-            utils::Compute(graph, {{"input", input}}, {{"output", result}});
-        DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success);
+        utils::Compute(graph, {{"input", input}}, {{"output", result}});
         executionTimeVector.push_back(std::chrono::high_resolution_clock::now() -
                                       executionStartTime);
     }
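
The same edit repeats in the MobileNetV2, ResNet, and SqueezeNet examples below: utils::Compute() no longer returns a status, so the DAWN_ASSERT on success is dropped and failures are reported through the context instead (see GraphBase::Compute further down). If an example still wanted to observe failures, one option would be an uncaptured-error callback on the context; a sketch, under the assumption that webnn-native mirrors Dawn's callback API here (this commit only shows the ConsumedError side):

    // Assumption: wnn::Context exposes a Dawn-style uncaptured-error callback.
    context.SetUncapturedErrorCallback(
        [](WNNErrorType type, const char* message, void* userdata) {
            dawn::ErrorLog() << "Compute failed: " << message;
        },
        nullptr);
    utils::Compute(graph, {{"input", input}}, {{"output", result}});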
10 changes: 3 additions & 7 deletions examples/MobileNetV2/Main.cpp
@@ -60,23 +60,19 @@ int main(int argc, const char* argv[]) {
     std::vector<float> result(utils::SizeOfShape(mobilevetv2.mOutputShape));
     // Do the first inference for warming up if nIter > 1.
     if (mobilevetv2.mNIter > 1) {
-        wnn::ComputeGraphStatus status =
-            utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
-        DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success);
+        utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
     }
 
     std::vector<TIME_TYPE> executionTime;
     for (int i = 0; i < mobilevetv2.mNIter; ++i) {
         std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
             std::chrono::high_resolution_clock::now();
-        wnn::ComputeGraphStatus status =
-            utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
-        DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success);
+        utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
         executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime);
     }
 
     // Print the result.
     utils::PrintExexutionTime(executionTime);
     utils::PrintResult(result, mobilevetv2.mLabelPath);
     dawn::InfoLog() << "Done.";
-}
+}
10 changes: 3 additions & 7 deletions examples/ResNet/Main.cpp
@@ -60,23 +60,19 @@ int main(int argc, const char* argv[]) {
     std::vector<float> result(utils::SizeOfShape(resnet.mOutputShape));
     // Do the first inference for warming up if nIter > 1.
     if (resnet.mNIter > 1) {
-        wnn::ComputeGraphStatus status =
-            utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
-        DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success);
+        utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
     }
 
     std::vector<TIME_TYPE> executionTime;
     for (int i = 0; i < resnet.mNIter; ++i) {
         std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
             std::chrono::high_resolution_clock::now();
-        wnn::ComputeGraphStatus status =
-            utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
-        DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success);
+        utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
         executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime);
     }
 
     // Print the result.
     utils::PrintExexutionTime(executionTime);
     utils::PrintResult(result, resnet.mLabelPath);
     dawn::InfoLog() << "Done.";
-}
+}
2 changes: 1 addition & 1 deletion examples/SampleUtils.cpp
@@ -282,7 +282,7 @@ namespace utils {
         return builder.Build(namedOperands);
     }
 
-    wnn::ComputeGraphStatus Compute(const wnn::Graph& graph,
+    void Compute(const wnn::Graph& graph,
                                     const std::vector<NamedInput<float>>& inputs,
                                     const std::vector<NamedOutput<float>>& outputs) {
         return Compute<float>(graph, inputs, outputs);
10 changes: 4 additions & 6 deletions examples/SampleUtils.h
@@ -243,12 +243,11 @@ namespace utils {
     };
 
     template <typename T>
-    wnn::ComputeGraphStatus Compute(const wnn::Graph& graph,
+    void Compute(const wnn::Graph& graph,
                                     const std::vector<NamedInput<T>>& inputs,
                                     const std::vector<NamedOutput<T>>& outputs) {
         if (graph.GetHandle() == nullptr) {
             dawn::ErrorLog() << "The graph is invaild.";
-            return wnn::ComputeGraphStatus::Error;
         }
 
         // The `mlInputs` local variable to hold the input data util computing the graph.
@@ -274,13 +273,12 @@ namespace utils {
             mlOutputs.push_back(resource);
             namedOutputs.Set(output.name.c_str(), &mlOutputs.back());
         }
-        wnn::ComputeGraphStatus status = graph.Compute(namedInputs, namedOutputs);
+        graph.Compute(namedInputs, namedOutputs);
         DoFlush();
 
-        return status;
     }
 
-    wnn::ComputeGraphStatus Compute(const wnn::Graph& graph,
+    void Compute(const wnn::Graph& graph,
                                     const std::vector<NamedInput<float>>& inputs,
                                     const std::vector<NamedOutput<float>>& outputs);
 
@@ -335,4 +333,4 @@ namespace utils {
                                     const std::string& powerPreference = "default");
 }  // namespace utils
 
-#endif  // WEBNN_NATIVE_EXAMPLES_SAMPLE_UTILS_H_
+#endif  // WEBNN_NATIVE_EXAMPLES_SAMPLE_UTILS_H_
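
With both overloads returning void, the helper is now fire-and-forget: callers hand in typed buffers and read the outputs afterwards, with no status to check. A minimal usage sketch (buffer sizes are illustrative):

    std::vector<float> input(28 * 28), result(10);
    // Fills `result` in place; any failure is consumed by the context.
    utils::Compute(graph, {{"input", input}}, {{"output", result}});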
10 changes: 3 additions & 7 deletions examples/SqueezeNet/Main.cpp
@@ -60,23 +60,19 @@ int main(int argc, const char* argv[]) {
     std::vector<float> result(utils::SizeOfShape(squeezenet.mOutputShape));
     // Do the first inference for warming up if nIter > 1.
     if (squeezenet.mNIter > 1) {
-        wnn::ComputeGraphStatus status =
-            utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
-        DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success);
+        utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
     }
 
     std::vector<TIME_TYPE> executionTime;
     for (int i = 0; i < squeezenet.mNIter; ++i) {
         std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
             std::chrono::high_resolution_clock::now();
-        wnn::ComputeGraphStatus status =
-            utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
-        DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success);
+        utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
         executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime);
     }
 
     // Print the result.
     utils::PrintExexutionTime(executionTime);
     utils::PrintResult(result, squeezenet.mLabelPath);
     dawn::InfoLog() << "Done.";
-}
+}
4 changes: 2 additions & 2 deletions node/src/Graph.cpp
@@ -133,9 +133,9 @@ namespace node {
         for (auto& output : outputs) {
             namedOutputs.Set(output.first.data(), &output.second);
         }
-        wnn::ComputeGraphStatus status = mImpl.Compute(namedInputs, namedOutputs);
+        mImpl.Compute(namedInputs, namedOutputs);
 
-        return Napi::Number::New(info.Env(), static_cast<uint32_t>(status));
+        return Napi::Number::New(info.Env(), 0);
     }
 
     Napi::Object Graph::Initialize(Napi::Env env, Napi::Object exports) {
6 changes: 0 additions & 6 deletions src/tests/unittests/native/GraphMockTests.cpp
@@ -37,11 +37,5 @@ namespace webnn_native { namespace {
         EXPECT_TRUE(graphMock.Compile().IsSuccess());
     }
 
-    TEST_F(GraphMockTests, Compute) {
-        EXPECT_CALL(graphMock, ComputeImpl).Times(1);
-        NamedInputsBase inputs;
-        NamedOutputsBase outputs;
-        EXPECT_TRUE(graphMock.Compute(&inputs, &outputs) == WNNComputeGraphStatus_Success);
-    }
 
 }}  // namespace webnn_native::
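
The status-based Compute test is deleted here rather than ported. An equivalent test against the new signature might look like the sketch below (my construction, not part of this commit; it assumes gMock and takes the empty MaybeError as success, as the null backend's `return {};` suggests):

    TEST_F(GraphMockTests, Compute) {
        // MaybeError is move-only, so stub the mock with a callable action
        // rather than Return().
        EXPECT_CALL(graphMock, ComputeImpl)
            .WillOnce([](NamedInputsBase*, NamedOutputsBase*) -> MaybeError {
                return {};  // empty MaybeError signals success
            });
        NamedInputsBase inputs;
        NamedOutputsBase outputs;
        graphMock.Compute(&inputs, &outputs);  // Compute() now returns void
    }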
2 changes: 1 addition & 1 deletion src/tests/unittests/native/mocks/GraphMock.h
@@ -56,7 +56,7 @@ namespace webnn_native {
                 (override));
     MOCK_METHOD(MaybeError, Finish, (), (override));
     MOCK_METHOD(MaybeError, CompileImpl, (), (override));
-    MOCK_METHOD(WNNComputeGraphStatus,
+    MOCK_METHOD(MaybeError,
                 ComputeImpl,
                 (NamedInputsBase * inputs, NamedOutputsBase* outputs),
                 (override));
22 changes: 10 additions & 12 deletions src/webnn_native/Graph.cpp
@@ -32,9 +32,9 @@ namespace webnn_native {
         MaybeError CompileImpl() override {
             UNREACHABLE();
         }
-        WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs,
+        MaybeError ComputeImpl(NamedInputsBase* inputs,
                                           NamedOutputsBase* outputs) override {
-            return WNNComputeGraphStatus_Error;
+            return DAWN_INTERNAL_ERROR("fail to build graph!");
         }
     };
 }  // namespace
@@ -138,27 +138,25 @@ namespace webnn_native {
         return CompileImpl();
     }
 
-    WNNComputeGraphStatus GraphBase::Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
+    void GraphBase::Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
         if (inputs == nullptr || outputs == nullptr) {
-            return WNNComputeGraphStatus_Error;
+            GetContext()->ConsumedError(DAWN_VALIDATION_ERROR("input or output is nullptr."));
         }
 
-        return ComputeImpl(inputs, outputs);
+        if (GetContext()->ConsumedError(ComputeImpl(inputs, outputs))) {
+            dawn::ErrorLog() << "fail to compute graph";
+        }
     }
 
     void GraphBase::ComputeAsync(NamedInputsBase* inputs,
                                  NamedOutputsBase* outputs,
                                  WNNComputeAsyncCallback callback,
                                  void* userdata) {
         if (inputs == nullptr || outputs == nullptr) {
-            callback(WNNComputeGraphStatus_Error, "named inputs or outputs is empty.", userdata);
+            callback(WNNErrorType_Validation, "named inputs or outputs is empty.", userdata);
         }
-        // TODO: Get error message from implemenation, ComputeImpl should return MaybeError,
-        // which is tracked with issues-959.
-        WNNComputeGraphStatus status = ComputeImpl(inputs, outputs);
-        std::string messages = status != WNNComputeGraphStatus_Success ? "Failed to async compute"
-                                                                       : "Success async compute";
-        callback(status, messages.c_str(), userdata);
+        std::unique_ptr<ErrorData> errorData = ComputeImpl(inputs, outputs).AcquireError();
+        callback(static_cast<WNNErrorType>(ToWNNErrorType(errorData->GetType())), const_cast<char*>(errorData->GetMessage().c_str()), userdata);
     }
 
     GraphBase::GraphBase(ContextBase* context, ObjectBase::ErrorTag tag)
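
For readers new to the Dawn-style plumbing used above: ConsumedError() takes a MaybeError, forwards any ErrorData to the context's error handling, and returns true if an error was present, while DAWN_TRY early-returns the error to the caller. Roughly, as a simplified sketch (not the project's actual macro definition):

    // Simplified; the real macro lives in the shared error-handling headers.
    #define DAWN_TRY(EXPR)                       \
        {                                        \
            MaybeError maybeError = (EXPR);      \
            if (maybeError.IsError()) {          \
                return maybeError;               \
            }                                    \
        }

The backend ComputeImpl() implementations below lean on exactly this pattern to replace their early `return WNNComputeGraphStatus_Error;` exits.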
4 changes: 2 additions & 2 deletions src/webnn_native/Graph.h
@@ -82,7 +82,7 @@ namespace webnn_native {
         virtual MaybeError Compile();
 
         // Webnn API
-        WNNComputeGraphStatus Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs);
+        void Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs);
         void ComputeAsync(NamedInputsBase* inputs,
                           NamedOutputsBase* outputs,
                           WNNComputeAsyncCallback callback,
@@ -93,7 +93,7 @@ namespace webnn_native {
 
       private:
        virtual MaybeError CompileImpl() = 0;
-        virtual WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs,
+        virtual MaybeError ComputeImpl(NamedInputsBase* inputs,
                                                   NamedOutputsBase* outputs) = 0;
     };
 }  // namespace webnn_native
10 changes: 4 additions & 6 deletions src/webnn_native/dml/GraphDML.cpp
@@ -1814,13 +1814,12 @@ namespace webnn_native::dml {
         return {};
     }
 
-    WNNComputeGraphStatus Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
+    MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
         auto namedInputs = inputs->GetRecords();
         for (auto& [name, inputBinding] : mInputBindingMap) {
             // All the inputs must be set.
             if (namedInputs.find(name) == namedInputs.end()) {
-                dawn::ErrorLog() << "The input must be set.";
-                return WNNComputeGraphStatus_Error;
+                return DAWN_INTERNAL_ERROR("All inputs must be set.");
             }
 
             auto& resource = namedInputs[name].resource;
@@ -1882,8 +1881,7 @@ namespace webnn_native::dml {
         std::vector<pydml::TensorData*> outputTensors;
         if (FAILED(mDevice->DispatchOperator(mCompiledModel->op.Get(), inputBindings,
                                              outputExpressions, outputTensors))) {
-            dawn::ErrorLog() << "Failed to dispatch operator.";
-            return WNNComputeGraphStatus_Error;
+            return DAWN_INTERNAL_ERROR("Fail to dispatch operator.");
         }
 
         for (size_t i = 0; i < outputNames.size(); ++i) {
@@ -1903,7 +1901,7 @@ namespace webnn_native::dml {
             delete tensor;
         }
 #endif
-        return WNNComputeGraphStatus_Success;
+        return {};
     }
 
 }  // namespace webnn_native::dml
2 changes: 1 addition & 1 deletion src/webnn_native/dml/GraphDML.h
@@ -100,7 +100,7 @@ namespace webnn_native::dml {
 
       private:
         MaybeError CompileImpl() override;
-        WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs,
+        MaybeError ComputeImpl(NamedInputsBase* inputs,
                                           NamedOutputsBase* outputs) override;
 
         ::dml::Expression BindingConstant(DML_TENSOR_DATA_TYPE dmlTensorType,
10 changes: 4 additions & 6 deletions src/webnn_native/mlas/GraphMLAS.cpp
@@ -978,13 +978,12 @@ namespace webnn_native::mlas {
         return {};
     }
 
-    WNNComputeGraphStatus Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
+    MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
         for (auto& [name, input] : inputs->GetRecords()) {
             Ref<Memory> inputMemory = mInputs.at(name);
             auto& resource = input.resource.arrayBufferView;
             if (inputMemory->GetByteLength() < resource.byteLength) {
-                dawn::ErrorLog() << "The size of input memory is less than input buffer.";
-                return WNNComputeGraphStatus_Error;
+                return DAWN_INTERNAL_ERROR("The size of input memory is less than input buffer.");
             }
             memcpy(inputMemory->GetBuffer(),
                    static_cast<int8_t*>(resource.buffer) + resource.byteOffset,
@@ -1005,13 +1004,12 @@ namespace webnn_native::mlas {
             Ref<Memory> outputMemory = mOutputs.at(outputName);
             const ArrayBufferView& output = outputs->GetRecords().at(outputName).arrayBufferView;
             if (output.byteLength < outputMemory->GetByteLength()) {
-                dawn::ErrorLog() << "The size of output buffer is less than output memory.";
-                return WNNComputeGraphStatus_Error;
+                return DAWN_INTERNAL_ERROR("The size of output buffer is less than output memory.");
             }
             memcpy(static_cast<int8_t*>(output.buffer) + output.byteOffset,
                    outputMemory->GetBuffer(), output.byteLength);
         }
-        return WNNComputeGraphStatus_Success;
+        return {};
     }
 
 }  // namespace webnn_native::mlas
2 changes: 1 addition & 1 deletion src/webnn_native/mlas/GraphMLAS.h
@@ -55,7 +55,7 @@ namespace webnn_native::mlas {
 
      private:
         MaybeError CompileImpl() override;
-        WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs,
+        MaybeError ComputeImpl(NamedInputsBase* inputs,
                                           NamedOutputsBase* outputs) override;
 
         std::unordered_map<std::string, Ref<Memory>> mInputs;
4 changes: 2 additions & 2 deletions src/webnn_native/null/ContextNull.cpp
@@ -68,8 +68,8 @@ namespace webnn_native::null {
         return {};
     }
 
-    WNNComputeGraphStatus Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
-        return WNNComputeGraphStatus_Success;
+    MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
+        return {};
     }
 
     MaybeError Graph::AddConstant(const op::Constant* constant) {
2 changes: 1 addition & 1 deletion src/webnn_native/null/ContextNull.h
@@ -76,7 +76,7 @@ namespace webnn_native::null {
 
       private:
         MaybeError CompileImpl() override;
-        WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs,
+        MaybeError ComputeImpl(NamedInputsBase* inputs,
                                           NamedOutputsBase* outputs) override;
     };
 
14 changes: 7 additions & 7 deletions src/webnn_native/onednn/GraphDNNL.cpp
@@ -989,20 +989,20 @@ namespace webnn_native::onednn {
         return {};
     }
 
-    WNNComputeGraphStatus Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
+    MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
         for (auto& [name, input] : inputs->GetRecords()) {
             dnnl_memory_t inputMemory = mInputMemoryMap.at(name);
             auto& resource = input.resource.arrayBufferView;
-            COMPUTE_TRY(dnnl_memory_set_data_handle_v2(
+            DAWN_TRY(dnnl_memory_set_data_handle_v2(
                 inputMemory, static_cast<int8_t*>(resource.buffer) + resource.byteOffset, mStream));
         }
 
         for (auto op : mOperations) {
-            COMPUTE_TRY(
+            DAWN_TRY(
                 dnnl_primitive_execute(op.primitive, mStream, op.args.size(), op.args.data()));
         }
 
-        COMPUTE_TRY(dnnl_stream_wait(mStream));
+        DAWN_TRY(dnnl_stream_wait(mStream));
 
         std::vector<std::string> outputNames;
         for (auto& [name, _] : outputs->GetRecords()) {
@@ -1013,17 +1013,17 @@ namespace webnn_native::onednn {
             std::string outputName = outputNames[i];
             dnnl_memory_t outputMemory = mOutputMemoryMap.at(outputName);
             const dnnl_memory_desc_t* outputMemoryDesc;
-            COMPUTE_TRY(GetMemoryDesc(outputMemory, &outputMemoryDesc));
+            DAWN_TRY(GetMemoryDesc(outputMemory, &outputMemoryDesc));
             size_t bufferLength = dnnl_memory_desc_get_size(outputMemoryDesc);
             void* outputBuffer = malloc(bufferLength);
-            COMPUTE_TRY(ReadFromMemory(outputBuffer, bufferLength, outputMemory));
+            DAWN_TRY(ReadFromMemory(outputBuffer, bufferLength, outputMemory));
             ArrayBufferView output = outputs->GetRecords().at(outputName).arrayBufferView;
             if (output.byteLength >= bufferLength) {
                 memcpy(static_cast<int8_t*>(output.buffer) + output.byteOffset, outputBuffer,
                        bufferLength);
             }
         }
-        return WNNComputeGraphStatus_Success;
+        return {};
     }
 
     dnnl_engine_t Graph::GetEngine() {
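
Note the oneDNN backend swaps its local COMPUTE_TRY macro for DAWN_TRY even though the dnnl_* calls return dnnl_status_t rather than MaybeError, so some status-to-MaybeError bridging is implied. One plausible shape for that adapter (an assumption on my part; the actual definition is not shown in this diff):

    // Hypothetical bridge: convert a oneDNN status code into a MaybeError so
    // DAWN_TRY(CheckDnnlStatus(...)) can propagate failures.
    MaybeError CheckDnnlStatus(dnnl_status_t status) {
        if (status != dnnl_success) {
            return DAWN_INTERNAL_ERROR("oneDNN call failed.");
        }
        return {};
    }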
2 changes: 1 addition & 1 deletion src/webnn_native/onednn/GraphDNNL.h
@@ -62,7 +62,7 @@ namespace webnn_native::onednn {
         dnnl_status_t BuildPrimitives();
 
         MaybeError CompileImpl() override;
-        WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs,
+        MaybeError ComputeImpl(NamedInputsBase* inputs,
                                           NamedOutputsBase* outputs) override;
         dnnl_engine_t GetEngine();
         dnnl_status_t GetMemoryDesc(dnnl_memory_t memory, const dnnl_memory_desc_t** desc);