Skip to content

Commit

Permalink
fix Elastix binary search on Windows + formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
schlegelp committed Jan 4, 2025
1 parent ab4de9e commit b294837
Showing 1 changed file with 78 additions and 51 deletions.
129 changes: 78 additions & 51 deletions navis/transforms/elastix.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,18 @@
import subprocess
import shutil
import tempfile
import platform

import numpy as np
import pandas as pd

from .base import BaseTransform
from ..utils import make_iterable

_search_path = [i for i in os.environ['PATH'].split(os.pathsep) if len(i) > 0]
_search_path = [i for i in os.environ["PATH"].split(os.pathsep) if len(i) > 0]


def find_elastixbin(tool: str = 'transformix') -> str:
def find_elastixbin(tool: str = "transformix") -> str:
"""Find directory with elastix binaries."""
for path in _search_path:
path = pathlib.Path(path)
Expand All @@ -45,15 +46,20 @@ def find_elastixbin(tool: str = 'transformix') -> str:
raise


_elastixbin = find_elastixbin()
if platform.system() == "Windows":
# On Windows, we have to search for `transformix.exe`
# We can still invoke it as `transformix` via the command line though
_elastixbin = find_elastixbin("transformix.exe")
else:
_elastixbin = find_elastixbin("transformix")


def setup_elastix():
"""Set up to make elastix work from inside a Python session.
Briefly: elastix requires the `LD_LIBRARY_PATH` (Linux) or `LDY_LIBRARY_PATH`
(OSX) environment variables to (also) point to the directory with the
elastix `lib` directory. For reasons unknown to me, these varibles do not
elastix `lib` directory. For reasons unknown to me, these variables do not
make it into the Python session. Hence, we have to set them here explicitly.
Above info is based on: https://github.com/jasper-tms/pytransformix
Expand All @@ -64,56 +70,59 @@ def setup_elastix():
return

# Check if this variable already exists
var = os.environ.get('LD_LIBRARY_PATH', os.environ.get('LDY_LIBRARY_PATH', ''))
var = os.environ.get("LD_LIBRARY_PATH", os.environ.get("LDY_LIBRARY_PATH", ""))

# Get the actual path
path = (_elastixbin.parent / 'lib').absolute()
path = (_elastixbin.parent / "lib").absolute()

if str(path) not in var:
var = f'{path}{os.pathsep}{var}' if var else str(path)
var = f"{path}{os.pathsep}{var}" if var else str(path)

# Note that `LD_LIBRARY_PATH` works for both Linux and OSX
os.environ['LD_LIBRARY_PATH'] = var
os.environ["LD_LIBRARY_PATH"] = var
# As per navis/issues/112
os.environ['DYLD_LIBRARY_PATH'] = var
os.environ["DYLD_LIBRARY_PATH"] = var


setup_elastix()


def requires_elastix(func):
"""Check if elastix is available."""

@functools.wraps(func)
def wrapper(*args, **kwargs):
if not _elastixbin:
raise ValueError("Could not find elastix binaries. Please download "
"the releases page at https://github.com/SuperElastix/elastix, "
"unzip at a convenient location and add that "
"location to your PATH variable. Note that you "
"will also have to set a LD_LIBRARY_PATH (Linux) "
"or DYLD_LIBRARY_PATH (OSX) variable. See the "
"elastic manual (release page) for details.")
raise ValueError(
"Could not find elastix binaries. Please download "
"the releases page at https://github.com/SuperElastix/elastix, "
"unzip at a convenient location and add that "
"location to your PATH variable. Note that you "
"will also have to set a LD_LIBRARY_PATH (Linux) "
"or DYLD_LIBRARY_PATH (OSX) variable. See the "
"elastic manual (release page) for details."
)
return func(*args, **kwargs)

return wrapper


@requires_elastix
def elastix_version(as_string=False):
"""Get elastix version."""
p = subprocess.run([_elastixbin / 'elastix', '--version'],
capture_output=True)
p = subprocess.run([_elastixbin / "elastix", "--version"], capture_output=True)
if p.stderr:
raise BaseException(f'Error running elastix:\n{p.stderr.decode()}')
raise BaseException(f"Error running elastix:\n{p.stderr.decode()}")

version = p.stdout.decode('utf-8').rstrip()
version = p.stdout.decode("utf-8").rstrip()

# Extract version from "elastix version: 5.0.1"
version = version.split(':')[-1]
version = version.split(":")[-1]

if as_string:
return version
else:
return tuple(int(v) for v in version.split('.'))
return tuple(int(v) for v in version.split("."))


class ElastixTransform(BaseTransform):
Expand Down Expand Up @@ -146,44 +155,48 @@ def __init__(self, file: str, copy_files=[]):
self.file = pathlib.Path(file)
self.copy_files = copy_files

def __eq__(self, other: 'ElastixTransform') -> bool:
def __eq__(self, other: "ElastixTransform") -> bool:
"""Implement equality comparison."""
if isinstance(other, ElastixTransform):
if self.file == other.file:
return True
return False

def check_if_possible(self, on_error: str = 'raise'):
def check_if_possible(self, on_error: str = "raise"):
"""Check if this transform is possible."""
if not _elastixbin:
msg = 'Folder with elastix binaries not found. Make sure the ' \
'directory is in your PATH environment variable.'
if on_error == 'raise':
msg = (
"Folder with elastix binaries not found. Make sure the "
"directory is in your PATH environment variable."
)
if on_error == "raise":
raise BaseException(msg)
return msg
if not self.file.is_file():
msg = f'Transformation file {self.file} not found.'
if on_error == 'raise':
msg = f"Transformation file {self.file} not found."
if on_error == "raise":
raise BaseException(msg)
return msg

def copy(self) -> 'ElastixTransform':
def copy(self) -> "ElastixTransform":
"""Return copy."""
# Attributes not to copy
no_copy = []
# Generate new empty transform
x = self.__class__(self.file)
# Override with this neuron's data
x.__dict__.update({k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy})
x.__dict__.update(
{k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy}
)

return x

def write_input_file(self, points, filepath):
"""Write a numpy array in format required by transformix."""
with open(filepath, 'w') as f:
f.write('point\n{}\n'.format(len(points)))
with open(filepath, "w") as f:
f.write("point\n{}\n".format(len(points)))
for x, y, z in points:
f.write(f'{x:f} {y:f} {z:f}\n')
f.write(f"{x:f} {y:f} {z:f}\n")

def read_output_file(self, filepath) -> np.ndarray:
"""Load output file.
Expand All @@ -200,10 +213,10 @@ def read_output_file(self, filepath) -> np.ndarray:
"""
points = []
with open(filepath, 'r') as f:
with open(filepath, "r") as f:
for line in f.readlines():
output = line.split('OutputPoint = [ ')[1].split(' ]')[0]
points.append([float(i) for i in output.split(' ')])
output = line.split("OutputPoint = [ ")[1].split(" ]")[0]
points.append([float(i) for i in output.split(" ")])
return np.array(points)

def xform(self, points: np.ndarray, return_logs=False) -> np.ndarray:
Expand All @@ -223,16 +236,20 @@ def xform(self, points: np.ndarray, return_logs=False) -> np.ndarray:
Transformed points.
"""
self.check_if_possible(on_error='raise')
self.check_if_possible(on_error="raise")

if isinstance(points, pd.DataFrame):
# Make sure x/y/z columns are present
if np.any([c not in points for c in ['x', 'y', 'z']]):
raise ValueError('points DataFrame must have x/y/z columns.')
points = points[['x', 'y', 'z']].values
elif not (isinstance(points, np.ndarray) and points.ndim == 2 and points.shape[1] == 3):
raise TypeError('`points` must be numpy array of shape (N, 3) or '
'pandas DataFrame with x/y/z columns')
if np.any([c not in points for c in ["x", "y", "z"]]):
raise ValueError("points DataFrame must have x/y/z columns.")
points = points[["x", "y", "z"]].values
elif not (
isinstance(points, np.ndarray) and points.ndim == 2 and points.shape[1] == 3
):
raise TypeError(
"`points` must be numpy array of shape (N, 3) or "
"pandas DataFrame with x/y/z columns"
)

# Everything happens in a temporary directory
with tempfile.TemporaryDirectory() as tempdir:
Expand All @@ -244,13 +261,21 @@ def xform(self, points: np.ndarray, return_logs=False) -> np.ndarray:
_ = pathlib.Path(shutil.copy(f, p))

# Write points to file
in_file = p / 'inputpoints.txt'
in_file = p / "inputpoints.txt"
self.write_input_file(points, in_file)

out_file = p / 'outputpoints.txt'
out_file = p / "outputpoints.txt"

# Prepare the command
command = [_elastixbin / 'transformix', '-out', str(p), '-tp', str(self.file), '-def', str(in_file)]
command = [
_elastixbin / "transformix",
"-out",
str(p),
"-tp",
str(self.file),
"-def",
str(in_file),
]

# Keep track of current working directory
cwd = os.getcwd()
Expand All @@ -269,16 +294,18 @@ def xform(self, points: np.ndarray, return_logs=False) -> np.ndarray:
os.chdir(cwd)

if return_logs:
logfile = p / 'transformix.log'
logfile = p / "transformix.log"
if not logfile.is_file():
raise FileNotFoundError('No log file found.')
raise FileNotFoundError("No log file found.")
with open(logfile) as f:
logs = f.read()
return logs

if not out_file.is_file():
raise FileNotFoundError('Elastix transform did not produce any '
f'output:\n {proc.stdout.decode()}')
raise FileNotFoundError(
"Elastix transform did not produce any "
f"output:\n {proc.stdout.decode()}"
)

points_xf = self.read_output_file(out_file)

Expand Down

0 comments on commit b294837

Please sign in to comment.