Skip to content

Commit

Permalink
Merge pull request #254 from lisa0314/refactor
Browse files Browse the repository at this point in the history
[Base]Refactor WebNN to align with Dawn
  • Loading branch information
fujunwei authored May 12, 2022
2 parents 3276a97 + 5a6604b commit 57a8f15
Show file tree
Hide file tree
Showing 259 changed files with 1,225 additions and 1,326 deletions.
2 changes: 1 addition & 1 deletion BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import("scripts/webnn_overrides_with_defaults.gni")

group("all") {
testonly = true
deps = [ "src/tests:webnn_tests" ]
deps = [ "src/webnn/tests:webnn_tests" ]
if (webnn_standalone) {
deps += [ "examples:webnn_samples" ]
}
Expand Down
2 changes: 1 addition & 1 deletion DEPS
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ hooks = [
'name': 'download_dml_unpkg',
'pattern': '.',
'condition': 'checkout_win',
'action': ['python3', 'src/webnn_native/dml/deps/script/download_dml.py'],
'action': ['python3', 'src/webnn/native/dml/deps/script/download_dml.py'],
}
]

Expand Down
12 changes: 6 additions & 6 deletions examples/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ static_library("webnn_sample_utils") {

# Export all of these as public deps so that `gn check` allows includes
public_deps = [
"${webnn_root}/src/common",
"${webnn_root}/src/utils:webnn_utils",
"${webnn_root}/src/webnn:cpp",
"${webnn_root}/src/webnn:webnn_proc",
"${webnn_root}/src/webnn:webnncpp",
"${webnn_root}/src/webnn_native",
"${webnn_root}/src/webnn_wire",
"${webnn_root}/src/webnn/common",
"${webnn_root}/src/webnn/native:webnn_native",
"${webnn_root}/src/webnn/utils:webnn_utils",
"${webnn_root}/src/webnn/wire:webnn_wire",
]

defines = [
Expand All @@ -58,7 +58,7 @@ static_library("webnn_sample_utils") {
]
}

public_configs = [ "${webnn_root}/src/common:dawn_internal" ]
public_configs = [ "${webnn_root}/src/webnn/common:internal_config" ]
if (is_linux) {
public_configs += [ "//build/config//gcc:rpath_for_built_shared_libraries" ]
}
Expand Down
36 changes: 18 additions & 18 deletions examples/SampleUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@

#include "common/Assert.h"
#include "common/Log.h"
#include "utils/TerribleCommandBuffer.h"
#include "webnn_native/NamedInputs.h"
#include "webnn_native/NamedOperands.h"
#include "webnn_native/NamedOutputs.h"

#include "webnn/native/NamedInputs.h"
#include "webnn/native/NamedOperands.h"
#include "webnn/native/NamedOutputs.h"
#include "webnn/utils/TerribleCommandBuffer.h"
#include "webnn/wire/WireClient.h"
#include "webnn/wire/WireServer.h"

#include <webnn/native/WebnnNative.h>
#include <webnn/webnn.h>
#include <webnn/webnn_cpp.h>
#include <webnn/webnn_proc.h>
#include <webnn_native/WebnnNative.h>
#include <webnn_wire/WireClient.h>
#include <webnn_wire/WireServer.h>
#include <algorithm>
#include <cmath>
#include <fstream>
Expand All @@ -45,16 +45,16 @@ static CmdBufType cmdBufType = CmdBufType::Terrible;
#else
static CmdBufType cmdBufType = CmdBufType::None;
#endif // defined(WEBNN_ENABLE_WIRE)
static webnn_wire::WireServer* wireServer = nullptr;
static webnn_wire::WireClient* wireClient = nullptr;
static webnn::wire::WireServer* wireServer = nullptr;
static webnn::wire::WireClient* wireClient = nullptr;
static utils::TerribleCommandBuffer* c2sBuf = nullptr;
static utils::TerribleCommandBuffer* s2cBuf = nullptr;

static wnn::Instance clientInstance;
static std::unique_ptr<webnn_native::Instance> nativeInstance;
static std::unique_ptr<webnn::native::Instance> nativeInstance;
wnn::Context CreateCppContext(wnn::ContextOptions const* options) {
nativeInstance = std::make_unique<webnn_native::Instance>();
WebnnProcTable backendProcs = webnn_native::GetProcs();
nativeInstance = std::make_unique<webnn::native::Instance>();
WebnnProcTable backendProcs = webnn::native::GetProcs();
WNNContext backendContext = nativeInstance->CreateContext(options);
if (backendContext == nullptr) {
return wnn::Context();
Expand All @@ -73,18 +73,18 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) {
c2sBuf = new utils::TerribleCommandBuffer();
s2cBuf = new utils::TerribleCommandBuffer();

webnn_wire::WireServerDescriptor serverDesc = {};
webnn::wire::WireServerDescriptor serverDesc = {};
serverDesc.procs = &backendProcs;
serverDesc.serializer = s2cBuf;

wireServer = new webnn_wire::WireServer(serverDesc);
wireServer = new webnn::wire::WireServer(serverDesc);
c2sBuf->SetHandler(wireServer);

webnn_wire::WireClientDescriptor clientDesc = {};
webnn::wire::WireClientDescriptor clientDesc = {};
clientDesc.serializer = c2sBuf;

wireClient = new webnn_wire::WireClient(clientDesc);
procs = webnn_wire::client::GetProcs();
wireClient = new webnn::wire::WireClient(clientDesc);
procs = webnn::wire::client::GetProcs();
s2cBuf->SetHandler(wireClient);

#ifdef ENABLE_INJECT_CONTEXT
Expand Down
2 changes: 1 addition & 1 deletion examples/SqueezeNet/SqueezeNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,4 @@ const wnn::Operand SqueezeNet::LoadNHWC(const wnn::GraphBuilder& builder, bool s
const wnn::Operand reshape = builder.Reshape(averagePool2d, newShape.data(), newShape.size());
const wnn::Operand output = softmax ? builder.Softmax(reshape) : reshape;
return output;
}
}
4 changes: 2 additions & 2 deletions generator/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import("webnn_generator.gni")
# autogenerated sources there.
_stale_dirs = [
"webnn",
"webnn_native",
"webnn_wire",
"webnn/native",
"webnn/wire",
"mock",
"src",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@
//* See the License for the specific language governing permissions and
//* limitations under the License.

#include "webnn_native/webnn_platform.h"
#include "webnn_native/WebnnNative.h"
#include "webnn/native/webnn_platform.h"
#include "webnn/native/WebnnNative.h"

#include <algorithm>
#include <vector>

{% for type in by_category["object"] %}
{% if type.name.canonical_case() not in ["texture view"] %}
#include "webnn_native/{{type.name.CamelCase()}}.h"
#include "webnn/native/{{type.name.CamelCase()}}.h"
{% endif %}
{% endfor %}

namespace webnn_native {
namespace webnn::native {

namespace {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
//* See the License for the specific language governing permissions and
//* limitations under the License.

#include "webnn_native/ValidationUtils_autogen.h"
#include "webnn/native/ValidationUtils_autogen.h"

namespace webnn_native {
namespace webnn::native {

{% for type in by_category["enum"] %}
MaybeError Validate{{type.name.CamelCase()}}(wnn::{{as_cppType(type.name)}} value) {
Expand All @@ -41,4 +41,4 @@ namespace webnn_native {

{% endfor %}

} // namespace webnn_native
} // namespace webnn::native
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@

#include "webnn/webnn_cpp.h"

#include "webnn_native/Error.h"
#include "webnn/native/Error.h"

namespace webnn_native {
namespace webnn::native {

// Helper functions to check the value of enums and bitmasks
{% for type in by_category["enum"] + by_category["bitmask"] %}
MaybeError Validate{{type.name.CamelCase()}}(wnn::{{as_cppType(type.name)}} value);
{% endfor %}

} // namespace webnn_native
} // namespace webnn::native

#endif // BACKEND_VALIDATIONUTILS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
#define WEBNN_NATIVE_WEBNN_PLATFORM_AUTOGEN_H_

#include "webnn/webnn_cpp.h"
#include "webnn_native/Forward.h"
#include "webnn/native/Forward.h"

namespace webnn_native {
namespace webnn::native {

template <typename T>
struct EnumCount;
Expand All @@ -31,4 +31,4 @@ namespace webnn_native {
{% endfor %}
}

#endif // WEBNN_NATIVE_WEBNN_PLATFORM_AUTOGEN_H_
#endif // WEBNN_NATIVE_WEBNN_PLATFORM_AUTOGEN_H_
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
//* See the License for the specific language governing permissions and
//* limitations under the License.

#include "webnn_native/webnn_structs_autogen.h"
#include "webnn/native/webnn_structs_autogen.h"

namespace webnn_native {
namespace webnn::native {

{% for type in by_category["structure"] %}
{% set CppType = as_cppType(type.name) %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
#define WEBNN_NATIVE_WEBNN_STRUCTS_H_

#include "webnn/webnn_cpp.h"
#include "webnn_native/Forward.h"
#include "webnn/native/Forward.h"

namespace webnn_native {
namespace webnn::native {

{% macro render_cpp_default_value(member) -%}
{%- if member.annotation in ["*", "const*", "const*const*"] and member.optional -%}
Expand Down Expand Up @@ -64,6 +64,6 @@ namespace webnn_native {

{% endfor %}

} // namespace webnn_native
} // namespace webnn::native

#endif // WEBNN_NATIVE_WEBNN_STRUCTS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include "common/ityp_array.h"

namespace webnn_wire {
namespace webnn::wire {

enum class ObjectType : uint32_t {
{% for type in by_category["object"] %}
Expand All @@ -29,7 +29,7 @@ namespace webnn_wire {
template <typename T>
using PerObjectType = ityp::array<ObjectType, T, {{len(by_category["object"])}}>;

} // namespace webnn_wire
} // namespace webnn::wire


#endif // WEBNN_WIRE_OBJECTTPYE_AUTOGEN_H_
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
//* See the License for the specific language governing permissions and
//* limitations under the License.

#include "webnn_wire/WireCmd_autogen.h"
#include "webnn/wire/WireCmd_autogen.h"

#include "common/Assert.h"
#include "webnn_wire/Wire.h"
#include "webnn/wire/Wire.h"

#include <algorithm>
#include <cstring>
Expand Down Expand Up @@ -376,7 +376,7 @@
}
{% endmacro %}

namespace webnn_wire {
namespace webnn::wire {

// Macro to simplify error handling, similar to DAWN_TRY but for DeserializeResult.
#define DESERIALIZE_TRY(EXPR) \
Expand Down Expand Up @@ -485,4 +485,4 @@ namespace webnn_wire {
{{ write_command_serialization_methods(command, True) }}
{% endfor %}

} // namespace webnn_wire
} // namespace webnn::wire
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

#include <webnn/webnn.h>

#include "webnn_wire/ObjectType_autogen.h"
#include "webnn/wire/ObjectType_autogen.h"

namespace webnn_wire {
namespace webnn::wire {

using ObjectId = uint32_t;
using ObjectGeneration = uint32_t;
Expand Down Expand Up @@ -141,6 +141,6 @@ namespace webnn_wire {
{{write_command_struct(command, True)}}
{% endfor %}

} // namespace webnn_wire
} // namespace webnn::wire

#endif // WEBNN_WIRE_WIRECMD_AUTOGEN_H_
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
#ifndef WEBNNNWIRE_CLIENT_APIOBJECTS_AUTOGEN_H_
#define WEBNNNWIRE_CLIENT_APIOBJECTS_AUTOGEN_H_

#include "webnn_wire/ObjectType_autogen.h"
#include "webnn_wire/client/ObjectBase.h"
#include "webnn/wire/ObjectType_autogen.h"
#include "webnn/wire/client/ObjectBase.h"

namespace webnn_wire::client {
namespace webnn::wire::client {

template <typename T>
struct ObjectTypeToTypeEnum {
Expand Down Expand Up @@ -49,6 +49,6 @@ namespace webnn_wire::client {
};

{% endfor %}
} // namespace webnn_wire::client
} // namespace webnn::wire::client

#endif // WEBNNNWIRE_CLIENT_APIOBJECTS_AUTOGEN_H_
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
//* limitations under the License.

#include "common/Log.h"
#include "webnn_wire/client/ApiObjects.h"
#include "webnn_wire/client/Client.h"
#include "webnn/wire/client/ApiObjects.h"
#include "webnn/wire/client/Client.h"

#include <algorithm>
#include <cstring>
#include <string>
#include <vector>

namespace webnn_wire::client {
namespace webnn::wire::client {
namespace {

//* Outputs an rvalue that's the number of elements a pointer member points to.
Expand Down Expand Up @@ -271,4 +271,4 @@ namespace webnn_wire::client {
const WebnnProcTable& GetProcs() {
return gProcTable;
}
} // namespace webnn_wire::client
} // namespace webnn::wire::client
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
#ifndef DAWNWIRE_CLIENT_CLIENTBASE_AUTOGEN_H_
#define DAWNWIRE_CLIENT_CLIENTBASE_AUTOGEN_H_

#include "webnn_wire/ChunkedCommandHandler.h"
#include "webnn_wire/WireCmd_autogen.h"
#include "webnn_wire/client/ApiObjects.h"
#include "webnn_wire/client/ObjectAllocator.h"
#include "webnn/wire/ChunkedCommandHandler.h"
#include "webnn/wire/WireCmd_autogen.h"
#include "webnn/wire/client/ApiObjects.h"
#include "webnn/wire/client/ObjectAllocator.h"

namespace webnn_wire::client {
namespace webnn::wire::client {

class ClientBase : public ChunkedCommandHandler, public ObjectIdProvider {
public:
Expand Down Expand Up @@ -66,6 +66,6 @@ namespace webnn_wire::client {
{% endfor %}
};

} // namespace webnn_wire::client
} // namespace webnn::wire::client

#endif // DAWNWIRE_CLIENT_CLIENTBASE_AUTOGEN_H_
Loading

0 comments on commit 57a8f15

Please sign in to comment.