Skip to content

Commit

Permalink
PERF: Do not import torch to reduce itk import time
Browse files Browse the repository at this point in the history
https://discourse.itk.org/t/torch-import-time/6354
Torch takes time to import. In types.py torch is imported
to check if it's present or not in the environment, but
it's never used.
Now, torch is detected with importlib.metadata and imported
later.

Do the same thing with xarray
  • Loading branch information
tbaudier committed Dec 15, 2023
1 parent 7d8b95d commit 963a6f4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
12 changes: 9 additions & 3 deletions Wrapping/Generators/Python/itk/support/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#
# ==========================================================================*/

import importlib
from importlib.metadata import metadata
import os
import re
import functools
Expand All @@ -24,17 +26,17 @@

_HAVE_XARRAY = False
try:
import xarray as xr
metadata('xarray')

_HAVE_XARRAY = True
except ImportError:
pass
_HAVE_TORCH = False
try:
import torch
metadata('torch')

_HAVE_TORCH = True
except ImportError:
except importlib.metadata.PackageNotFoundError:
pass


Expand Down Expand Up @@ -84,6 +86,10 @@ def accept_array_like_xarray_torch(image_filter):
If a xarray DataArray is passed as an input, output itk.Image's are converted to xarray.DataArray's."""
import numpy as np
import itk
if _HAVE_XARRAY:
import xarray as xr
if _HAVE_TORCH:
import torch

@functools.wraps(image_filter)
def image_filter_wrapper(*args, **kwargs):
Expand Down
16 changes: 9 additions & 7 deletions Wrapping/Generators/Python/itk/support/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#
# ==========================================================================*/

import importlib
from importlib.metadata import metadata
from typing import Union, Optional, Tuple, TYPE_CHECKING
import os

Expand All @@ -26,17 +28,17 @@

_HAVE_XARRAY = False
try:
import xarray as xr
metadata('xarray')

_HAVE_XARRAY = True
except ImportError:
except importlib.metadata.PackageNotFoundError:
pass
_HAVE_TORCH = False
try:
import torch
metadata('torch')

_HAVE_TORCH = True
except ImportError:
except importlib.metadata.PackageNotFoundError:
pass

# noinspection PyPep8Naming
Expand Down Expand Up @@ -218,11 +220,11 @@ def initialize_c_types_once() -> (
ImageOrImageSource = Union[ImageBase, ImageSource]
# Can be coerced into an itk.ImageBase
if _HAVE_XARRAY and _HAVE_TORCH:
ImageLike = Union[ImageBase, ArrayLike, xr.DataArray, torch.Tensor]
ImageLike = Union[ImageBase, ArrayLike, "xr.DataArray", "torch.Tensor"]
elif _HAVE_XARRAY:
ImageLike = Union[ImageBase, ArrayLike, xr.DataArray]
ImageLike = Union[ImageBase, ArrayLike, "xr.DataArray"]
elif _HAVE_TORCH:
ImageLike = Union[ImageBase, ArrayLike, torch.Tensor]
ImageLike = Union[ImageBase, ArrayLike, "torch.Tensor"]
else:
ImageLike = Union[ImageBase, ArrayLike]

Expand Down

0 comments on commit 963a6f4

Please sign in to comment.