From 9031d4806edb68a64ebb9c25552879b777a0be83 Mon Sep 17 00:00:00 2001 From: Sunita Nadampalli Date: Tue, 12 Mar 2024 19:05:37 -0500 Subject: [PATCH] aarch64: apply the cherrypicked onednn PR-1768 This is to improve the torch.compile() perf by 5.8x on AWS Graviton3 instances. This patching is required till PyTorch oneDNN is upgraded to v3.4. --- aarch64_linux/aarch64_wheel_ci_build.py | 3 +++ aarch64_linux/build_aarch64_wheel.py | 1 + 2 files changed, 4 insertions(+) diff --git a/aarch64_linux/aarch64_wheel_ci_build.py b/aarch64_linux/aarch64_wheel_ci_build.py index 24989ec16..6f3d09a70 100755 --- a/aarch64_linux/aarch64_wheel_ci_build.py +++ b/aarch64_linux/aarch64_wheel_ci_build.py @@ -111,6 +111,9 @@ def parse_arguments(): with open("/builder/mkldnn_fix/fix-xbyak-failure.patch") as f: check_call(["patch", "-p1"], stdin=f, cwd="/pytorch/third_party/ideep/mkl-dnn") + print("Applying mkl-dnn patch to improve torch.compile() perf") + os.system("cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/onednn-pr1768-aarch64-add-acl-sbgemm-inner-product-primitive.patch") # noqa: E501 + os.system(f"cd /pytorch; {build_vars} python3 setup.py bdist_wheel") pytorch_wheel_name = complete_wheel("pytorch") print(f"Build Compelete. Created {pytorch_wheel_name}..") diff --git a/aarch64_linux/build_aarch64_wheel.py b/aarch64_linux/build_aarch64_wheel.py index 1615c78a6..e9dbb7caf 100755 --- a/aarch64_linux/build_aarch64_wheel.py +++ b/aarch64_linux/build_aarch64_wheel.py @@ -558,6 +558,7 @@ def start_build(host: RemoteHost, *, build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON" host.run_cmd("cd $HOME && git clone https://github.com/pytorch/builder.git") host.run_cmd("cd $HOME/pytorch/third_party/ideep/mkl-dnn && patch -p1 < $HOME/builder/mkldnn_fix/fix-xbyak-failure.patch") # noqa: E501 + host.run_cmd("cd $HOME/pytorch/third_party/ideep/mkl-dnn && patch -p1 < $HOME/builder/mkldnn_fix/onednn-pr1768-aarch64-add-acl-sbgemm-inner-product-primitive.patch") # noqa: E501 host.run_cmd(f"cd $HOME/pytorch && export ACL_ROOT_DIR=$HOME/ComputeLibrary && {build_vars} python3 setup.py bdist_wheel{build_opts}") # noqa: E501 print('Repair the wheel') pytorch_wheel_name = host.list_dir("pytorch/dist")[0]