Skip to content

Commit

Permalink
[Base] make ComputeImpl return MaybeError
Browse files Browse the repository at this point in the history
  • Loading branch information
lisa0314 authored and mingmingtasd committed Apr 18, 2022
1 parent 8b4e8a6 commit af19a53
Show file tree
Hide file tree
Showing 30 changed files with 113 additions and 215 deletions.
4 changes: 1 addition & 3 deletions examples/LeNet/Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ int main(int argc, const char* argv[]) {
for (int i = 0; i < nIter; ++i) {
std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
std::chrono::high_resolution_clock::now();
wnn::ComputeGraphStatus status =
utils::Compute(graph, {{"input", input}}, {{"output", result}});
DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success);
utils::Compute(graph, {{"input", input}}, {{"output", result}});
executionTimeVector.push_back(std::chrono::high_resolution_clock::now() -
executionStartTime);
}
Expand Down
10 changes: 3 additions & 7 deletions examples/MobileNetV2/Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,19 @@ int main(int argc, const char* argv[]) {
std::vector<float> result(utils::SizeOfShape(mobilevetv2.mOutputShape));
// Do the first inference for warming up if nIter > 1.
if (mobilevetv2.mNIter > 1) {
wnn::ComputeGraphStatus status =
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success);
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
}

std::vector<TIME_TYPE> executionTime;
for (int i = 0; i < mobilevetv2.mNIter; ++i) {
std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
std::chrono::high_resolution_clock::now();
wnn::ComputeGraphStatus status =
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success);
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime);
}

// Print the result.
utils::PrintExexutionTime(executionTime);
utils::PrintResult(result, mobilevetv2.mLabelPath);
dawn::InfoLog() << "Done.";
}
}
10 changes: 3 additions & 7 deletions examples/ResNet/Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,19 @@ int main(int argc, const char* argv[]) {
std::vector<float> result(utils::SizeOfShape(resnet.mOutputShape));
// Do the first inference for warming up if nIter > 1.
if (resnet.mNIter > 1) {
wnn::ComputeGraphStatus status =
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success);
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
}

std::vector<TIME_TYPE> executionTime;
for (int i = 0; i < resnet.mNIter; ++i) {
std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
std::chrono::high_resolution_clock::now();
wnn::ComputeGraphStatus status =
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success);
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime);
}

// Print the result.
utils::PrintExexutionTime(executionTime);
utils::PrintResult(result, resnet.mLabelPath);
dawn::InfoLog() << "Done.";
}
}
6 changes: 3 additions & 3 deletions examples/SampleUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,9 @@ namespace utils {
return builder.Build(namedOperands);
}

wnn::ComputeGraphStatus Compute(const wnn::Graph& graph,
const std::vector<NamedInput<float>>& inputs,
const std::vector<NamedOutput<float>>& outputs) {
void Compute(const wnn::Graph& graph,
const std::vector<NamedInput<float>>& inputs,
const std::vector<NamedOutput<float>>& outputs) {
return Compute<float>(graph, inputs, outputs);
}

Expand Down
19 changes: 8 additions & 11 deletions examples/SampleUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,11 @@ namespace utils {
};

template <typename T>
wnn::ComputeGraphStatus Compute(const wnn::Graph& graph,
const std::vector<NamedInput<T>>& inputs,
const std::vector<NamedOutput<T>>& outputs) {
void Compute(const wnn::Graph& graph,
const std::vector<NamedInput<T>>& inputs,
const std::vector<NamedOutput<T>>& outputs) {
if (graph.GetHandle() == nullptr) {
dawn::ErrorLog() << "The graph is invaild.";
return wnn::ComputeGraphStatus::Error;
}

// The `mlInputs` local variable to hold the input data until computing the graph.
Expand All @@ -275,15 +274,13 @@ namespace utils {
mlOutputs.push_back(resource);
namedOutputs.Set(output.name.c_str(), &mlOutputs.back());
}
wnn::ComputeGraphStatus status = graph.Compute(namedInputs, namedOutputs);
graph.Compute(namedInputs, namedOutputs);
DoFlush();

return status;
}

wnn::ComputeGraphStatus Compute(const wnn::Graph& graph,
const std::vector<NamedInput<float>>& inputs,
const std::vector<NamedOutput<float>>& outputs);
void Compute(const wnn::Graph& graph,
const std::vector<NamedInput<float>>& inputs,
const std::vector<NamedOutput<float>>& outputs);

template <class T>
bool CheckValue(const std::vector<T>& value, const std::vector<T>& expectedValue) {
Expand Down Expand Up @@ -336,4 +333,4 @@ namespace utils {
const std::string& powerPreference = "default");
} // namespace utils

#endif // WEBNN_NATIVE_EXAMPLES_SAMPLE_UTILS_H_
#endif // WEBNN_NATIVE_EXAMPLES_SAMPLE_UTILS_H_
10 changes: 3 additions & 7 deletions examples/SqueezeNet/Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,19 @@ int main(int argc, const char* argv[]) {
std::vector<float> result(utils::SizeOfShape(squeezenet.mOutputShape));
// Do the first inference for warming up if nIter > 1.
if (squeezenet.mNIter > 1) {
wnn::ComputeGraphStatus status =
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success);
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
}

std::vector<TIME_TYPE> executionTime;
for (int i = 0; i < squeezenet.mNIter; ++i) {
std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
std::chrono::high_resolution_clock::now();
wnn::ComputeGraphStatus status =
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
DAWN_ASSERT(status == wnn::ComputeGraphStatus::Success);
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime);
}

// Print the result.
utils::PrintExexutionTime(executionTime);
utils::PrintResult(result, squeezenet.mLabelPath);
dawn::InfoLog() << "Done.";
}
}
4 changes: 2 additions & 2 deletions node/src/Graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ namespace node {
for (auto& output : outputs) {
namedOutputs.Set(output.first.data(), &output.second);
}
wnn::ComputeGraphStatus status = mImpl.Compute(namedInputs, namedOutputs);
mImpl.Compute(namedInputs, namedOutputs);

return Napi::Number::New(info.Env(), static_cast<uint32_t>(status));
return Napi::Number::New(info.Env(), 0);
}

Napi::Object Graph::Initialize(Napi::Env env, Napi::Object exports) {
Expand Down
6 changes: 0 additions & 6 deletions src/tests/unittests/native/GraphMockTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,5 @@ namespace webnn_native { namespace {
EXPECT_TRUE(graphMock.Compile().IsSuccess());
}

TEST_F(GraphMockTests, Compute) {
EXPECT_CALL(graphMock, ComputeImpl).Times(1);
NamedInputsBase inputs;
NamedOutputsBase outputs;
EXPECT_TRUE(graphMock.Compute(&inputs, &outputs) == WNNComputeGraphStatus_Success);
}

}} // namespace webnn_native::
2 changes: 1 addition & 1 deletion src/tests/unittests/native/mocks/GraphMock.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ namespace webnn_native {
(override));
MOCK_METHOD(MaybeError, Finish, (), (override));
MOCK_METHOD(MaybeError, CompileImpl, (), (override));
MOCK_METHOD(WNNComputeGraphStatus,
MOCK_METHOD(MaybeError,
ComputeImpl,
(NamedInputsBase * inputs, NamedOutputsBase* outputs),
(override));
Expand Down
24 changes: 8 additions & 16 deletions src/webnn_native/Graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ namespace webnn_native {
MaybeError CompileImpl() override {
UNREACHABLE();
}
WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs,
NamedOutputsBase* outputs) override {
return WNNComputeGraphStatus_Error;
MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override {
return DAWN_INTERNAL_ERROR("fail to build graph!");
}
};
} // namespace
Expand Down Expand Up @@ -138,27 +137,20 @@ namespace webnn_native {
return CompileImpl();
}

WNNComputeGraphStatus GraphBase::Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
if (inputs == nullptr || outputs == nullptr) {
return WNNComputeGraphStatus_Error;
}

return ComputeImpl(inputs, outputs);
void GraphBase::Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
GetContext()->ConsumedError(ComputeImpl(inputs, outputs));
}

void GraphBase::ComputeAsync(NamedInputsBase* inputs,
NamedOutputsBase* outputs,
WNNComputeAsyncCallback callback,
void* userdata) {
if (inputs == nullptr || outputs == nullptr) {
callback(WNNComputeGraphStatus_Error, "named inputs or outputs is empty.", userdata);
callback(WNNErrorType_Validation, "named inputs or outputs is empty.", userdata);
}
// TODO: Get error message from implementation, ComputeImpl should return MaybeError,
// which is tracked with issues-959.
WNNComputeGraphStatus status = ComputeImpl(inputs, outputs);
std::string messages = status != WNNComputeGraphStatus_Success ? "Failed to async compute"
: "Success async compute";
callback(status, messages.c_str(), userdata);
std::unique_ptr<ErrorData> errorData = ComputeImpl(inputs, outputs).AcquireError();
callback(static_cast<WNNErrorType>(ToWNNErrorType(errorData->GetType())),
const_cast<char*>(errorData->GetMessage().c_str()), userdata);
}

GraphBase::GraphBase(ContextBase* context, ObjectBase::ErrorTag tag)
Expand Down
5 changes: 2 additions & 3 deletions src/webnn_native/Graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ namespace webnn_native {
virtual MaybeError Compile();

// Webnn API
WNNComputeGraphStatus Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs);
void Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs);
void ComputeAsync(NamedInputsBase* inputs,
NamedOutputsBase* outputs,
WNNComputeAsyncCallback callback,
Expand All @@ -93,8 +93,7 @@ namespace webnn_native {

private:
virtual MaybeError CompileImpl() = 0;
virtual WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs,
NamedOutputsBase* outputs) = 0;
virtual MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) = 0;
};
} // namespace webnn_native

Expand Down
13 changes: 4 additions & 9 deletions src/webnn_native/dml/GraphDML.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1814,14 +1814,10 @@ namespace webnn_native::dml {
return {};
}

WNNComputeGraphStatus Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
auto namedInputs = inputs->GetRecords();
for (auto& [name, inputBinding] : mInputBindingMap) {
// All the inputs must be set.
if (namedInputs.find(name) == namedInputs.end()) {
dawn::ErrorLog() << "The input must be set.";
return WNNComputeGraphStatus_Error;
}
DAWN_INVALID_IF(namedInputs.find(name) == namedInputs.end(), "all inputs must be set.");

auto& resource = namedInputs[name].resource;
if (resource.arrayBufferView.buffer != nullptr) {
Expand Down Expand Up @@ -1882,8 +1878,7 @@ namespace webnn_native::dml {
std::vector<pydml::TensorData*> outputTensors;
if (FAILED(mDevice->DispatchOperator(mCompiledModel->op.Get(), inputBindings,
outputExpressions, outputTensors))) {
dawn::ErrorLog() << "Failed to dispatch operator.";
return WNNComputeGraphStatus_Error;
return DAWN_INTERNAL_ERROR("Fail to dispatch operator.");
}

for (size_t i = 0; i < outputNames.size(); ++i) {
Expand All @@ -1903,7 +1898,7 @@ namespace webnn_native::dml {
delete tensor;
}
#endif
return WNNComputeGraphStatus_Success;
return {};
}

} // namespace webnn_native::dml
3 changes: 1 addition & 2 deletions src/webnn_native/dml/GraphDML.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ namespace webnn_native::dml {

private:
MaybeError CompileImpl() override;
WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs,
NamedOutputsBase* outputs) override;
MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override;

::dml::Expression BindingConstant(DML_TENSOR_DATA_TYPE dmlTensorType,
::dml::TensorDimensions dmlTensorDims,
Expand Down
16 changes: 6 additions & 10 deletions src/webnn_native/mlas/GraphMLAS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -978,14 +978,12 @@ namespace webnn_native::mlas {
return {};
}

WNNComputeGraphStatus Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
for (auto& [name, input] : inputs->GetRecords()) {
Ref<Memory> inputMemory = mInputs.at(name);
auto& resource = input.resource.arrayBufferView;
if (inputMemory->GetByteLength() < resource.byteLength) {
dawn::ErrorLog() << "The size of input memory is less than input buffer.";
return WNNComputeGraphStatus_Error;
}
DAWN_INVALID_IF(inputMemory->GetByteLength() < resource.byteLength,
"The size of input memory is less than input buffer.");
memcpy(inputMemory->GetBuffer(),
static_cast<int8_t*>(resource.buffer) + resource.byteOffset,
resource.byteLength);
Expand All @@ -1004,14 +1002,12 @@ namespace webnn_native::mlas {
std::string outputName = outputNames[i];
Ref<Memory> outputMemory = mOutputs.at(outputName);
const ArrayBufferView& output = outputs->GetRecords().at(outputName).arrayBufferView;
if (output.byteLength < outputMemory->GetByteLength()) {
dawn::ErrorLog() << "The size of output buffer is less than output memory.";
return WNNComputeGraphStatus_Error;
}
DAWN_INVALID_IF(output.byteLength < outputMemory->GetByteLength(),
"The size of output buffer is less than output memory.");
memcpy(static_cast<int8_t*>(output.buffer) + output.byteOffset,
outputMemory->GetBuffer(), output.byteLength);
}
return WNNComputeGraphStatus_Success;
return {};
}

} // namespace webnn_native::mlas
3 changes: 1 addition & 2 deletions src/webnn_native/mlas/GraphMLAS.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ namespace webnn_native::mlas {

private:
MaybeError CompileImpl() override;
WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs,
NamedOutputsBase* outputs) override;
MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override;

std::unordered_map<std::string, Ref<Memory>> mInputs;
std::unordered_map<std::string, Ref<Memory>> mOutputs;
Expand Down
4 changes: 2 additions & 2 deletions src/webnn_native/null/ContextNull.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ namespace webnn_native::null {
return {};
}

WNNComputeGraphStatus Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
return WNNComputeGraphStatus_Success;
MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
return {};
}

MaybeError Graph::AddConstant(const op::Constant* constant) {
Expand Down
3 changes: 1 addition & 2 deletions src/webnn_native/null/ContextNull.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ namespace webnn_native::null {

private:
MaybeError CompileImpl() override;
WNNComputeGraphStatus ComputeImpl(NamedInputsBase* inputs,
NamedOutputsBase* outputs) override;
MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override;
};

} // namespace webnn_native::null
Expand Down
Loading

0 comments on commit af19a53

Please sign in to comment.