From 248cc8ad46f5319599356ee531c730128a3b9a24 Mon Sep 17 00:00:00 2001 From: Rok Mandeljc Date: Fri, 8 Mar 2024 20:24:17 +0100 Subject: [PATCH] hooks: torch: add support for MKL-enabled torch builds Add support for MKL-enabled `torch` builds on Windows (e.g., the nightly `2.3.0.dev20240308+cpu` build). The `torch` hook now attempts to discover and collect DLLs from MKL and its dependencies (`mkl`, `tbb`, `intel-openmp`). This should prevent the frozen program from silently crashing due to missing MKL DLLs, which slip past PyInstaller's binary dependency analysis because they are dynamically loaded at run time. --- news/712.update.rst | 4 ++ .../hooks/stdhooks/hook-torch.py | 61 ++++++++++++++++++- 2 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 news/712.update.rst diff --git a/news/712.update.rst b/news/712.update.rst new file mode 100644 index 000000000..208ca2c5f --- /dev/null +++ b/news/712.update.rst @@ -0,0 +1,4 @@ +Update ``torch`` hook to add support for MKL-enabled ``torch`` builds +on Windows (e.g., the nightly ``2.3.0.dev20240308+cpu`` build). The hook +now attempts to discover and collect DLLs from MKL and its dependencies +(``mkl``, ``tbb``, ``intel-openmp``). 
\ No newline at end of file diff --git a/src/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-torch.py b/src/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-torch.py index 01b6b3ae4..53dc636be 100644 --- a/src/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-torch.py +++ b/src/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-torch.py @@ -10,6 +10,8 @@ # SPDX-License-Identifier: GPL-2.0-or-later # ------------------------------------------------------------------ +import os + from PyInstaller.utils.hooks import ( logger, collect_data_files, @@ -20,7 +22,7 @@ ) if is_module_satisfies("PyInstaller >= 6.0"): - from PyInstaller.compat import is_linux + from PyInstaller.compat import is_linux, is_win from PyInstaller.utils.hooks import PY_DYLIB_PATTERNS module_collection_mode = "pyz+py" @@ -70,6 +72,63 @@ def _infer_nvidia_hiddenimports(): nvidia_hiddenimports = [] logger.info("hook-torch: inferred hidden imports for CUDA libraries: %r", nvidia_hiddenimports) hiddenimports += nvidia_hiddenimports + + # The Windows nightly build for torch 2.3.0 added dependency on MKL. The `mkl` distribution does not provide an + # importable package, but rather installs the DLLs in <env>/Library/bin directory. Therefore, we cannot write a + # separate hook for it, and must collect the DLLs here. (Most of these DLLs are missed by PyInstaller's binary + # dependency analysis due to being dynamically loaded at run-time). 
+ if is_win: + def _collect_mkl_dlls(): + import packaging.requirements + from _pyinstaller_hooks_contrib.compat import importlib_metadata + + # Check if torch depends on `mkl` + dist = importlib_metadata.distribution("torch") + requirements = [packaging.requirements.Requirement(req) for req in dist.requires or []] + requirements = [req.name for req in requirements if req.marker is None or req.marker.evaluate()] + if 'mkl' not in requirements: + logger.info('hook-torch: this torch build does not depend on MKL...') + return [] # This torch build does not depend on MKL + + # Find requirements of mkl - this should yield `intel-openmp` and `tbb`, which install DLLs in the same + # way as `mkl`. + try: + dist = importlib_metadata.distribution("mkl") + except importlib_metadata.PackageNotFoundError: + return [] # For some reason, `mkl` distribution is unavailable. + requirements = [packaging.requirements.Requirement(req) for req in dist.requires or []] + requirements = [req.name for req in requirements if req.marker is None or req.marker.evaluate()] + + requirements = ['mkl'] + requirements + + mkl_binaries = [] + logger.info('hook-torch: collecting DLLs from MKL and its dependencies: %r', requirements) + for requirement in requirements: + try: + dist = importlib_metadata.distribution(requirement) + except importlib_metadata.PackageNotFoundError: + continue + + # Go over files, and match DLLs in <env>/Library/bin directory + for dist_file in dist.files: + if not dist_file.match('../../Library/bin/*.dll'): + continue + dll_file = dist.locate_file(dist_file).resolve() + mkl_binaries.append((str(dll_file), '.')) + + logger.info( + 'hook-torch: found MKL DLLs: %r', + sorted([os.path.basename(src_name) for src_name, dest_name in mkl_binaries]) + ) + return mkl_binaries + + try: + mkl_binaries = _collect_mkl_dlls() + except Exception: + # Log the exception, but make it non-fatal + logger.warning("hook-torch: failed to collect MKL DLLs!", exc_info=True) + mkl_binaries = [] + 
binaries += mkl_binaries else: datas = [(get_package_paths("torch")[1], "torch")]