From 072d8b2280569a2d13b91d3ed51546d201a57366 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Andr=C3=A9s=20Margffoy=20Tuay?= Date: Thu, 15 Oct 2020 03:59:51 -0500 Subject: [PATCH] Fix external DLL loading on wheels (#2811) * Fix external DLL loading on wheels * Use SetDefaultDllDirectoriess and AddDllDirectory * Add previous paths * Trigger debug * Do not call SetDefaultDllDirectories. * Do not call loadlibrary if the extensions were not compiled * Fix lint issues --- packaging/wheel/relocate.py | 6 ++++++ torchvision/extension.py | 23 +++++++++++++++++++++++ torchvision/io/_video_opt.py | 25 ++++++++++++++++++++++++- torchvision/io/image.py | 26 +++++++++++++++++++++++++- 4 files changed, 78 insertions(+), 2 deletions(-) diff --git a/packaging/wheel/relocate.py b/packaging/wheel/relocate.py index aa2ecb6036a..e9bd07bef97 100644 --- a/packaging/wheel/relocate.py +++ b/packaging/wheel/relocate.py @@ -259,6 +259,12 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): def relocate_dll_library(dumpbin, output_dir, output_library, binary): + """ + Relocate a DLL/PE shared library to be packaged on a wheel. + + Given a shared library, find the transitive closure of its dependencies, + rename and copy them into the wheel. + """ print('Relocating {0}'.format(binary)) binary_path = osp.join(output_library, binary) diff --git a/torchvision/extension.py b/torchvision/extension.py index 69433be3b0f..265c989a8ce 100644 --- a/torchvision/extension.py +++ b/torchvision/extension.py @@ -12,6 +12,29 @@ def _register_extensions(): # load the custom_op_library and register the custom ops lib_dir = os.path.dirname(__file__) + if os.name == 'nt': + # Register the main torchvision library location on the default DLL path + import ctypes + import sys + + kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) + with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') + prev_error_mode = kernel32.SetErrorMode(0x0001) + + if with_load_library_flags: + kernel32.AddDllDirectory.restype = ctypes.c_void_p + + if sys.version_info >= (3, 8): + os.add_dll_directory(lib_dir) + elif with_load_library_flags: + res = kernel32.AddDllDirectory(lib_dir) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' + raise err + + kernel32.SetErrorMode(prev_error_mode) + loader_details = ( importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index ae4b0f7c869..347b367ab27 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -13,7 +13,7 @@ _HAS_VIDEO_OPT = False try: - lib_dir = os.path.join(os.path.dirname(__file__), "..") + lib_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) loader_details = ( importlib.machinery.ExtensionFileLoader, @@ -22,6 +22,29 @@ extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) ext_specs = extfinder.find_spec("video_reader") + + if os.name == 'nt': + # Load the video_reader extension using LoadLibraryExW + import ctypes + import sys + + kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) + with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') + prev_error_mode = kernel32.SetErrorMode(0x0001) + + if with_load_library_flags: + kernel32.LoadLibraryExW.restype = ctypes.c_void_p + + if ext_specs is not None: + res = kernel32.LoadLibraryExW(ext_specs.origin, None, 0x00001100) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += (f' Error loading "{ext_specs.origin}" or any or ' + 'its dependencies.') + raise err + + kernel32.SetErrorMode(prev_error_mode) + if ext_specs is not None: torch.ops.load_library(ext_specs.origin) _HAS_VIDEO_OPT = True diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 1e93c0c8d2a..2279be3ad10 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -7,7 +7,7 @@ _HAS_IMAGE_OPT = False try: - lib_dir = osp.join(osp.dirname(__file__), "..") + lib_dir = osp.abspath(osp.join(osp.dirname(__file__), "..")) loader_details = ( importlib.machinery.ExtensionFileLoader, @@ -16,6 +16,30 @@ extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) # type: ignore[arg-type] ext_specs = extfinder.find_spec("image") + + if os.name == 'nt': + # Load the image extension using LoadLibraryExW + import ctypes + import sys + + kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) + with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') + prev_error_mode = kernel32.SetErrorMode(0x0001) + + kernel32.LoadLibraryW.restype = ctypes.c_void_p + if with_load_library_flags: + kernel32.LoadLibraryExW.restype = ctypes.c_void_p + + if ext_specs is not None: + res = kernel32.LoadLibraryExW(ext_specs.origin, None, 0x00001100) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += (f' Error loading "{ext_specs.origin}" or any or ' + 'its dependencies.') + raise err + + kernel32.SetErrorMode(prev_error_mode) + if ext_specs is not None: torch.ops.load_library(ext_specs.origin) _HAS_IMAGE_OPT = True