Skip to content

Commit

Permalink
aarch64: apply the cherrypicked onednn PR-1768
Browse files Browse the repository at this point in the history
This is to improve the torch.compile() perf by 5.8x
on AWS Graviton3 instances. This patching is required
till PyTorch oneDNN is upgraded to v3.4.
  • Loading branch information
snadampal committed Mar 7, 2024
1 parent 09a674e commit 32d9e8e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
3 changes: 3 additions & 0 deletions aarch64_linux/aarch64_wheel_ci_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def parse_arguments():
print("Applying mkl-dnn patch to fix crash due to /sys not accesible")
os.system("cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/fix-xbyak-failure.patch")

print("Applying mkl-dnn patch to improve torch.compile() perf")
os.system("cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/onednn-pr1768-aarch64-add-acl-sbgemm-inner-product-primitive.patch") # noqa: E501

os.system(f"cd /pytorch; {build_vars} python3 setup.py bdist_wheel")
pytorch_wheel_name = complete_wheel("pytorch")
print(f"Build Compelete. Created {pytorch_wheel_name}..")
1 change: 1 addition & 0 deletions aarch64_linux/build_aarch64_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ def start_build(host: RemoteHost, *,
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
host.run_cmd("cd $HOME && git clone https://github.com/pytorch/builder.git")
host.run_cmd("cd $HOME/pytorch/third_party/ideep/mkl-dnn && patch -p1 < $HOME/builder/mkldnn_fix/fix-xbyak-failure.patch") # noqa: E501
host.run_cmd("cd $HOME/pytorch/third_party/ideep/mkl-dnn && patch -p1 < $HOME/builder/mkldnn_fix/onednn-pr1768-aarch64-add-acl-sbgemm-inner-product-primitive.patch") # noqa: E501
host.run_cmd(f"cd $HOME/pytorch && export ACL_ROOT_DIR=$HOME/ComputeLibrary && {build_vars} python3 setup.py bdist_wheel{build_opts}") # noqa: E501
print('Repair the wheel')
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
Expand Down

0 comments on commit 32d9e8e

Please sign in to comment.