Skip to content

Commit

Permalink
hooks: torch: add support for MKL-enabled torch builds
Browse files Browse the repository at this point in the history
Add support for MKL-enabled `torch` builds on Windows (e.g., the
nightly `2.3.0.dev20240308+cpu` build). The `torch` hook now
attempts to discover and collect DLLs from MKL and its dependencies
(`mkl`, `tbb`, `intel-openmp`).

This should prevent the frozen program from silently crashing due to
missing MKL DLLs, which slip past PyInstaller's binary dependency
analysis because they are dynamically loaded at run time.
  • Loading branch information
rokm committed Mar 8, 2024
1 parent 409bd3c commit 248cc8a
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 1 deletion.
4 changes: 4 additions & 0 deletions news/712.update.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Update ``torch`` hook to add support for MKL-enabled ``torch`` builds
on Windows (e.g., the nightly ``2.3.0.dev20240308+cpu`` build). The hook
now attempts to discover and collect DLLs from MKL and its dependencies
(``mkl``, ``tbb``, ``intel-openmp``).
61 changes: 60 additions & 1 deletion src/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
# SPDX-License-Identifier: GPL-2.0-or-later
# ------------------------------------------------------------------

import os

from PyInstaller.utils.hooks import (
logger,
collect_data_files,
Expand All @@ -20,7 +22,7 @@
)

if is_module_satisfies("PyInstaller >= 6.0"):
from PyInstaller.compat import is_linux
from PyInstaller.compat import is_linux, is_win
from PyInstaller.utils.hooks import PY_DYLIB_PATTERNS

module_collection_mode = "pyz+py"
Expand Down Expand Up @@ -70,6 +72,63 @@ def _infer_nvidia_hiddenimports():
nvidia_hiddenimports = []
logger.info("hook-torch: inferred hidden imports for CUDA libraries: %r", nvidia_hiddenimports)
hiddenimports += nvidia_hiddenimports

# The Windows nightly build for torch 2.3.0 added a dependency on MKL. The `mkl` distribution does not provide an
# importable package, but rather installs the DLLs in the <env>/Library/bin directory. Therefore, we cannot write a
# separate hook for it, and must collect the DLLs here. (Most of these DLLs are missed by PyInstaller's binary
# dependency analysis because they are dynamically loaded at run-time.)
if is_win:
def _collect_mkl_dlls():
    """
    Collect the DLLs installed by the `mkl` distribution and its direct requirements.

    Returns a list of ``(source_path, '.')`` tuples, one per DLL that the distributions' metadata lists under
    ``<env>/Library/bin``. An empty list is returned if the installed `torch` distribution does not require `mkl`,
    or if the `mkl` distribution cannot be found.
    """
    import packaging.requirements
    from _pyinstaller_hooks_contrib.compat import importlib_metadata

    def _active_requirement_names(dist):
        # Names of the distribution's requirements whose environment marker is either absent or evaluates to true
        # in the current environment. `dist.requires` may be None, hence the `or []`.
        parsed = [packaging.requirements.Requirement(req) for req in dist.requires or []]
        return [req.name for req in parsed if req.marker is None or req.marker.evaluate()]

    # Check if torch depends on `mkl`
    if 'mkl' not in _active_requirement_names(importlib_metadata.distribution("torch")):
        logger.info('hook-torch: this torch build does not depend on MKL...')
        return []  # This torch build does not depend on MKL

    # Find requirements of mkl - this should yield `intel-openmp` and `tbb`, which install DLLs in the same
    # way as `mkl`.
    try:
        mkl_dist = importlib_metadata.distribution("mkl")
    except importlib_metadata.PackageNotFoundError:
        # For some reason, `mkl` distribution is unavailable. Report it instead of silently collecting nothing.
        logger.warning('hook-torch: torch depends on MKL, but the `mkl` distribution was not found!')
        return []

    requirements = ['mkl'] + _active_requirement_names(mkl_dist)

    mkl_binaries = []
    logger.info('hook-torch: collecting DLLs from MKL and its dependencies: %r', requirements)
    for requirement in requirements:
        try:
            dist = importlib_metadata.distribution(requirement)
        except importlib_metadata.PackageNotFoundError:
            continue

        # Go over files, and match DLLs in <env>/Library/bin directory. `dist.files` is None when the distribution
        # has no file listing in its metadata; treat that as "no files" so that DLLs collected from the other
        # distributions are not lost to a TypeError.
        for dist_file in dist.files or []:
            if not dist_file.match('../../Library/bin/*.dll'):
                continue
            dll_file = dist.locate_file(dist_file).resolve()
            mkl_binaries.append((str(dll_file), '.'))

    logger.info(
        'hook-torch: found MKL DLLs: %r',
        sorted(os.path.basename(src_name) for src_name, _ in mkl_binaries)
    )
    return mkl_binaries

# Collecting MKL DLLs is best-effort: start from an empty list, so that a failure in the helper is logged as
# a warning (with traceback) but remains non-fatal for the hook.
mkl_binaries = []
try:
    mkl_binaries = _collect_mkl_dlls()
except Exception:
    logger.warning("hook-torch: failed to collect MKL DLLs!", exc_info=True)
binaries += mkl_binaries
else:
datas = [(get_package_paths("torch")[1], "torch")]

Expand Down

0 comments on commit 248cc8a

Please sign in to comment.