Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

feat: oneDNN wrapper based on oneDNN_jll #156

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ scripts
test_ext

benchmarks/results

deps
1 change: 1 addition & 0 deletions .typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ numer = "numer"
nd = "nd"
Ba = "Ba"
skipt = "skipt"
abd = "abd"
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.2.2"
version = "1.3.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
Expand All @@ -28,6 +29,7 @@ SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
oneDNN_jll = "3523a63d-8698-5b6f-b2c2-68eaa6bde0f0"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Expand All @@ -54,9 +56,10 @@ AMDGPU = "0.9.6, 1"
AppleAccelerate = "0.4"
ArrayInterface = "7.9"
BLISBLAS = "0.1"
CEnum = "0.5"
CUDA = "5.3.2"
ChainRulesCore = "1.24"
Compat = "4.15.0"
Compat = "4.15"
CpuId = "0.3"
DispatchDoctor = "0.4.12"
EnzymeCore = "0.7.7"
Expand All @@ -83,3 +86,4 @@ Statistics = "1.10"
Tracker = "0.2.34"
cuDNN = "1.3"
julia = "1.10"
oneDNN_jll = "3.5.3"
7 changes: 7 additions & 0 deletions generators/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[deps]
Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31"
oneDNN_jll = "3523a63d-8698-5b6f-b2c2-68eaa6bde0f0"

[compat]
Clang = "0.18"
oneDNN_jll = "3.5.3"
212 changes: 212 additions & 0 deletions generators/generator.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
[general]
# it could also be an expression as long as `Meta.parse` can parse this string successfully.
# basically, it should be the `expression` in the following code:
# ccall((function_name, expression), returntype, (argtype1, ...), argvalue1, ...)
library_name = "libdnnl"

# this entry allows you to specify different library names for different headers.
# in the following example:
# library_names = {"config.h" = "libclang_config", "libclang_p.*.h" = "libclang_patch"}
# those functions in the `config.h` will be generated as:
# ccall((function_name, libclang_config), returntype, (argtype1, ...), argvalue1, ...)
library_names = {}

# output file path relative to the working directory
output_file_path = "../src/onednn/lib.jl"

# if these are set, common file (types and constants) and API file (functions) will be separated
# this is for compatibility, so prologue and epilogue are not supported.
# output_api_file_path = "api.jl"
# output_common_file_path = "common.jl"

# if this entry is not empty, the generator will print the code below to the `output_file_path`.
# module module_name
#
# end # module
module_name = "Lib"

# if this entry is not empty, the generator will print the code below to the `output_file_path`.
# using jll_pkg_name
# export jll_pkg_name
jll_pkg_name = "oneDNN_jll"

# for packages that have extra JLL package dependencies
jll_pkg_extra = []

# identifiers that starts with the string listed in this entry will be exported.
export_symbol_prefixes = ["CX", "clang_"]

# the code in the following file will be copy-pasted to `output_file_path` before the generated code.
# this is often used for applying custom patches, e.g. adding missing definitions.
prologue_file_path = "prologue.jl"

# the code in the following file will be copy-pasted to `output_file_path` after the generated code.
# this is often used for applying custom patches.
epilogue_file_path = ""

# node with an id in the `output_ignorelist` will be ignored in the printing passes.
# this is very useful for custom editing.
output_ignorelist = [
"CINDEX_EXPORTS",
"CINDEX_VERSION",
"CINDEX_VERSION_STRING",
"CINDEX_LINKAGE",
"CINDEX_DEPRECATED",
"LLVM_CLANG_C_STRICT_PROTOTYPES_BEGIN",
"LLVM_CLANG_C_STRICT_PROTOTYPES_END",
"LLVM_CLANG_C_EXTERN_C_BEGIN",
"LLVM_CLANG_C_EXTERN_C_END"
]

# Julia's `@enum` do not allow duplicated values, so by default, C enums are translated to
# CEnum.jl's `@cenum`.
# if this entry is true, `@enum` is used and those duplicated enum constants are just commented.
use_julia_native_enum_type = false

# use `@cenum` but do not print `using CEnum`.
# this is useful in the case of using `CEnum` directly in the source tree instead of using `CEnum` as a dependency
print_using_CEnum = false

# Print enums directly as integers without @(c)enum wrapper
# Override above two options
print_enum_as_integer = false

# use deterministic symbol instead of `gensym`-generated `var"##XXX"`
use_deterministic_symbol = true

# by default, only those declarations in the local header file are processed.
# those declarations in the system headers will be treated specially and will be generated if necessary.
# if you'd like to generate all of the symbols in the system headers, please set this option to false.
is_local_header_only = true

# set this option to false if you'd like to ignore the symbols(even if necessary) in the system headers.
generate_isystem_symbols = true

# if this option is set to true, C code with a style of
# ```c
# typedef struct {
# int x;
# } my_struct;
# ```
# will be generated as:
# ```julia
# struct my_struct
# x::Cint
# end
# ```
# instead of
# ```julia
# struct var"##Ctag#NUM"
# x::Cint
# end
# const my_struct = var"##Ctag#NUM"
# ```
smart_de_anonymize = true

# if set to true, static functions will be ignored
skip_static_functions = false

# EXPERIMENTAL
# if this option is set to true, those structs that are not necessary to be an
# immutable struct will be generated as a mutable struct.
# this option is default to false, do read the paragraph below before using this feature.
auto_mutability = false

# add inner constructor `Foo() = new()`
auto_mutability_with_new = true

# if you feel like certain structs should not be generated as mutable struct, please add them in the following list.
# for example, if a C function accepts a `Vector` of some type as its argument like:
# void foo(mutable_type *list, int n);
# when calling this function via `ccall`, passing a `Vector{mutable_type}(undef, n)` to the first
# argument will trigger a crash, the reason is mutable structs are not stored inline within a `Vector`,
# one should use `Ref{NTuple{n,mutable_type}}()` instead.
# this is not convenient and that's where the `auto_mutability_ignorelist` comes in.
auto_mutability_ignorelist = []

# opposite to `auto_mutability_ignorelist` and has a higher priority
auto_mutability_includelist = []

# if set to "raw", extract and dump raw c comment;
# if set to "doxygen", parse and format doxygen comment.
# note: by default, Clang only parses doxygen comment, pass `-fparse-all-comments` to Clang in order to parse non-doxygen comments.
extract_c_comment_style = "doxygen"

# Pass a function to explicitly generate documentation. It will be called like
# `callback_documentation(node::ExprNode, doc::Vector{String})` if it is
# set. The `doc` argument will contain the docs parsed from the headers if
# `extract_c_comment_style` is set, otherwise it will be an empty vector.
#
# Do *not* set this in the TOML file, it should be set in the generator script
# to a function that takes in an ExprNode and returns a String[] (string
# vector).
# callback_documentation = ""

# if set to true, single line comment will be printed as """comment""" instead of """\ncomment\n"""
fold_single_line_comment = false

# if set to "outofline", documentation of struct fields will be collected at the "Fields" section of the struct
# if set to "inline", documentation of struct fields will go right above struct definition
struct_field_comment_style = "outofline"

# if set to "outofline", documentation of enumerators will be collected at the "Enumerators" section of the enum
enumerator_comment_style = "outofline"

# if set to true, C function prototype will be included in documentation
show_c_function_prototype = false

[codegen]
# map C's bool to Julia's Bool instead of `Cuchar` a.k.a `UInt8`.
use_julia_bool = true

# set this to true if the C routine always expects a NUL-terminated string.
# TODO: support filtering
always_NUL_terminated_string = true

# generate strictly typed function
is_function_strictly_typed = false

# if true, opaque pointers in function arguments will be translated to `Ptr{Cvoid}`.
opaque_func_arg_as_PtrCvoid = false

# if true, opaque types are translated to `mutable struct` instead of `Cvoid`.
opaque_as_mutable_struct = true

# if true, use Julia 1.5's new `@ccall` macro
use_ccall_macro = true

# if true, variadic functions are wrapped with `@ccall` macro. Otherwise variadic functions are ignored.
wrap_variadic_function = false

# generate getproperty/setproperty! methods for the types in the following list
field_access_method_list = []

# the generator will prefix the function argument names in the following list with a "_" to
# prevent the generated symbols from conflicting with the symbols defined and exported in Base.
function_argument_conflict_symbols = []

# emit constructors for all custom-layout structs like bitfield in the list,
# or set to `true` to do so for all such structs
add_record_constructors = []

[codegen.macro]
# it‘s highly recommended to set this entry to "basic".
# if you'd like to skip all of the macros, please set this entry to "disable".
# if you'd like to translate function-like macros to Julia, please set this entry to "aggressive".
macro_mode = "basic"

# function-like macros in the following list will always be translated.
functionlike_macro_includelist = [
"CINDEX_VERSION_ENCODE",
]

# if true, the generator prints the following message as comments.
# "# Skipping MacroDefinition: ..."
add_comment_for_skipped_macro = true

# if true, ignore any macros that is suffixed with "_H" or in the `ignore_header_guards_with_suffixes` list
ignore_header_guards = true
ignore_header_guards_with_suffixes = []

# if true, ignore those pure definition macros in the C code
ignore_pure_definition = true
6 changes: 6 additions & 0 deletions generators/prologue.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using CEnum: @cenum

const NULL = C_NULL

# This file is automatically generated by Clang.jl. Don't edit it manually. If needed,
# look at the "generators/" directory and modify the relevant files there.
62 changes: 62 additions & 0 deletions generators/wrap.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using Clang.Generators
using oneDNN_jll

cur_dir = pwd()

cd(@__DIR__)

include_dir = joinpath(oneDNN_jll.artifact_dir, "include")

options = load_options(joinpath(@__DIR__, "generator.toml"))

onednn_headers = [
joinpath(include_dir, "dnnl.h"),
joinpath(include_dir, "dnnl_types.h"),
joinpath(include_dir, "dnnl_config.h"),
joinpath(include_dir, "dnnl_version.h")
]

args = get_default_args()
push!(args, "-I$include_dir")

ctx = create_context(onednn_headers, args, options)

# run generator
build!(ctx, BUILDSTAGE_NO_PRINTING)

function rewrite!(e::Expr)
# const DNNL_RUNTIME_SIZE_VAL = size_t(DNNL_RUNTIME_DIM_VAL)
if e.head == :const && e.args[1] isa Expr && e.args[1].head == :(=) &&
e.args[1].args[1] == :DNNL_RUNTIME_SIZE_VAL && e.args[1].args[2] isa Expr &&
e.args[1].args[2].head == :call && e.args[1].args[2].args[1] == :size_t &&
e.args[1].args[2].args[2] == :DNNL_RUNTIME_DIM_VAL
e.args[1].args[2] = unsigned(typemin(Int64))
return
end
# const DNNL_RUNTIME_DIM_VAL = INT64_MIN
if e.head == :const && e.args[1] isa Expr && e.args[1].head == :(=) &&
e.args[1].args[1] == :DNNL_RUNTIME_DIM_VAL && e.args[1].args[2] == :INT64_MIN
e.args[1].args[2] = typemin(Int64)
return
end
# const DNNL_RUNTIME_S32_VAL = DNNL_RUNTIME_S32_VAL_REP
if e.head == :const && e.args[1] isa Expr && e.args[1].head == :(=) &&
e.args[1].args[1] == :DNNL_RUNTIME_S32_VAL &&
e.args[1].args[2] == :DNNL_RUNTIME_S32_VAL_REP
e.args[1].args[2] = 0
return
end
return
end

function rewrite!(dag::ExprDAG)
for node in get_nodes(dag), expr in get_exprs(node)
rewrite!(expr)
end
end

rewrite!(ctx.dag)

build!(ctx, BUILDSTAGE_PRINTING_ONLY)

cd(cur_dir)
4 changes: 4 additions & 0 deletions src/LuxLib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ const CRC = ChainRulesCore

include("utils.jl")
include("traits.jl")

include("onednn/oneDNN.jl")

include("impl/Impl.jl")

include("api/API.jl")

@compat(public,
Expand Down
Loading
Loading