diff --git a/navis/config.py b/navis/config.py index dd380d17..8b802906 100644 --- a/navis/config.py +++ b/navis/config.py @@ -17,6 +17,8 @@ import matplotlib as mpl +from .units import ureg # noqa: F401 + logger = logging.getLogger('navis') @@ -69,9 +71,6 @@ def remove_log_handlers(): # Default color for neurons default_color = (.95, .65, .04) -# Unit registry -ureg = pint.UnitRegistry() - # Set to true to prevent Viewer from ever showing headless = os.environ.get('NAVIS_HEADLESS', 'False').lower() == 'true' if headless: diff --git a/navis/nbl/nblast_funcs.py b/navis/nbl/nblast_funcs.py index 72f8bdf1..0da7b6b8 100644 --- a/navis/nbl/nblast_funcs.py +++ b/navis/nbl/nblast_funcs.py @@ -22,14 +22,18 @@ import pandas as pd from concurrent.futures import ProcessPoolExecutor -from typing import Union, Optional, List +from typing import Union, Optional from typing_extensions import Literal +import pint + +from pint.quantity import Quantity from navis.nbl.smat import Lookup2d, smat_fcwb from .. import utils, config from ..core import NeuronList, Dotprops, make_dotprops from .base import Blaster, NestedIndices +from ..units import as_unit __all__ = ['nblast', 'nblast_smart', 'nblast_allbyall', 'sim_to_dist'] @@ -105,7 +109,7 @@ def __init__(self, use_alpha=False, normalized=True, smat='auto', else: self.distance_upper_bound = limit_dist - def append(self, dotprops) -> NestedIndices: + def append(self, dotprops, ignore_units=False) -> NestedIndices: """Append dotprops. Returns the numerical index appended dotprops. @@ -113,15 +117,28 @@ def append(self, dotprops) -> NestedIndices: return a (possibly nested) list of indices. """ if isinstance(dotprops, Dotprops): - return self._append_dotprops(dotprops) + return self._append_dotprops(dotprops, ignore_units) try: - return [self.append(n) for n in dotprops] + return [self.append(n, ignore_units) for n in dotprops] except TypeError: # i.e. not iterable raise ValueError(f"Expected Dotprops or iterable thereof; got {type(dotprops)}") - def _append_dotprops(self, dotprops: Dotprops) -> int: - next_id = len(self) + def _append_dotprops(self, dotprops: Dotprops, ignore_units) -> int: + if not ignore_units: + # if isinstance(dotprops.units, pint.Quantity): + # if np.allclose(1, dotprops.units): + # units = dotprops.units.units + # else: + # logger.warning( + # "Dotprops coordinates are not unitary (%s). " + # "This might lead to unexpected results.", + # dotprops.units + # ) + # elif dotprops: + # units = as_unit(dotprops.units) + + # if as_unit(dotprops.units) self.neurons.append(dotprops) self.ids.append(dotprops.id) # Calculate score for self hit diff --git a/navis/nbl/smat.py b/navis/nbl/smat.py index 8aa9ee8c..a18fc8ed 100644 --- a/navis/nbl/smat.py +++ b/navis/nbl/smat.py @@ -28,14 +28,18 @@ import math from collections import defaultdict +import pint import numpy as np import pandas as pd +from ..config import ureg from ..core.neurons import Dotprops +from ..units import parse_quantity, reduce_units, as_unit, DIMENSIONLESS logger = logging.getLogger(__name__) DEFAULT_SEED = 1991 +DIMENSIONLESS = pint.Unit("dimensionless") epsilon = sys.float_info.epsilon cpu_count = max(1, (os.cpu_count() or 2) - 1) @@ -433,7 +437,23 @@ def is_monotonically_increasing(lst): return True -def parse_boundary(item: str): +def parse_interval(item: str) -> Tuple[Tuple[float, float], bool, pint.Unit]: + """Parse a string representing a half-open interval. + + Parameters + ---------- + item : str + An interval formatted like ``[1 nm, inf nm)`` + + Returns + ------- + Tuple[Tuple[float, float], bool, pint.Unit] + ( + (lower_bound, upper_bound), + include_right, + units + ) + """ explicit_interval = item[0] + item[-1] if explicit_interval == "[)": right = False @@ -443,7 +463,18 @@ def parse_boundary(item: str): raise ValueError( f"Enclosing characters '{explicit_interval}' do not match a half-open interval" ) - return tuple(float(i) for i in item[1:-1].split(",")), right + + out = [] + units = [] + for s in item[1:-1].split(","): + q = parse_quantity(s) + out.append(q.magnitude) + units.append(q.units) + + if len(out) != 2: + raise ValueError(f"Could not parse interval, got {len(out)} values instead of 2") + + return tuple(out), right, reduce_units(units) T = TypeVar("T") @@ -452,6 +483,8 @@ def parse_boundary(item: str): class LookupAxis(ABC, Generic[T]): """Class converting some data into a linear index.""" + units: pint.Unit = DIMENSIONLESS + @abstractmethod def __len__(self) -> int: """Number of bins represented by this instance.""" @@ -473,6 +506,13 @@ def __call__(self, value: Union[T, Sequence[T]]) -> Union[int, Sequence[int]]: """ pass + def same_units(self, unit: Union[str, pint.Unit]) -> bool: + try: + reduce_units([self.units, as_unit(unit)]) + except ValueError: + return False + return True + class SimpleLookup(LookupAxis[Hashable]): def __init__(self, items: List[Hashable]): @@ -508,6 +548,8 @@ def __init__( boundaries: Sequence[float], clip: Tuple[bool, bool] = (True, True), right=False, + *, + units=DIMENSIONLESS, ): """Class converting continuous values into discrete indices. @@ -546,6 +588,7 @@ def __init__( raise ValueError("Boundaries are not monotonically increasing") self.boundaries = np.asarray(boundaries) + self.units = as_unit(units) def __len__(self): return len(self.boundaries) - 1 @@ -570,8 +613,14 @@ def to_strings(self) -> List[str]: b = self.boundaries.copy() b[0] = self._min b[-1] = self._max + + if self.units == DIMENSIONLESS: + unit = "" + else: + unit = " " + str(self.units) + return [ - f"{lb}{lower},{upper}{rb}" + f"{lb}{lower}{unit},{upper}{unit}{rb}" for lower, upper in zip(b[:-1], b[1:]) ] @@ -596,8 +645,10 @@ def from_strings(cls, interval_strs: Sequence[str]): bounds: List[float] = [] last_upper = None last_right = None + units = [] for item in interval_strs: - (lower, upper), right = parse_boundary(item) + (lower, upper), right, unit = parse_interval(item) + units.append(unit) bounds.append(float(lower)) if last_right is not None: @@ -613,10 +664,10 @@ def from_strings(cls, interval_strs: Sequence[str]): last_upper = upper bounds.append(float(last_upper)) - return cls(bounds, right=last_right) + return cls(bounds, right=last_right, units=reduce_units(units)) @classmethod - def from_linear(cls, lower: float, upper: float, nbins: int, right=False): + def from_linear(cls, lower: float, upper: float, nbins: int, right=False, *, units=DIMENSIONLESS): """Choose digitizer boundaries spaced linearly between two values. Input values will be clipped to fit within the given interval. @@ -638,10 +689,10 @@ def from_linear(cls, lower: float, upper: float, nbins: int, right=False): Digitizer """ arr = np.linspace(lower, upper, nbins + 1, endpoint=True) - return cls(arr, right=right) + return cls(arr, right=right, units=units) @classmethod - def from_geom(cls, lowest_upper: float, highest_lower: float, nbins: int, right=False): + def from_geom(cls, lowest_upper: float, highest_lower: float, nbins: int, right=False, *, units=DIMENSIONLESS): """Choose digitizer boundaries in a geometric sequence. Additional bins will be added above and below the given values. @@ -664,7 +715,7 @@ def from_geom(cls, lowest_upper: float, highest_lower: float, nbins: int, right= Digitizer """ arr = np.geomspace(lowest_upper, highest_lower, nbins - 1, True) - return cls(arr, clip=(False, False), right=right) + return cls(arr, clip=(False, False), right=right, units=units) @classmethod def from_data(cls, data: Sequence[float], nbins: int, right=False): @@ -684,8 +735,13 @@ def from_data(cls, data: Sequence[float], nbins: int, right=False): ------- Digitizer """ + if isinstance(data, pint.Quantity): + data = data.magnitude + units = data.units + else: + units = DIMENSIONLESS arr = np.quantile(data, np.linspace(0, 1, nbins + 1, True)) - return cls(arr, right=right) + return cls(arr, right=right, units=units) def __eq__(self, other: object) -> bool: if not isinstance(other, Digitizer): @@ -712,6 +768,10 @@ def __call__(self, *args): out = self.cells[idxs] return out + @property + def units(self): + return tuple(ax.units for ax in self.axes) + class Lookup2d(LookupNd): """Convenience class inheriting from LookupNd for the common 2D float case. @@ -776,7 +836,8 @@ def _smat_fcwb(alpha=False): fname = ("smat_fcwb.csv", "smat_alpha_fcwb.csv")[alpha] fpath = smat_path / fname - return Lookup2d.from_dataframe(pd.read_csv(fpath, index_col=0)) + lookup = Lookup2d.from_dataframe(pd.read_csv(fpath, index_col=0)) + lookup.axes[0].units = ureg.Unit("micrometer") def smat_fcwb(alpha=False): diff --git a/navis/units.py b/navis/units.py new file mode 100644 index 00000000..85a39535 --- /dev/null +++ b/navis/units.py @@ -0,0 +1,97 @@ +import re +from typing import Optional, Union, Sequence + +import numpy as np +import pint + +DIMENSIONLESS = pint.Unit("dimensionless") + +INF_RE = re.compile(r"(?P-?)inf(inity)?") + +ureg = pint.UnitRegistry() +ureg.define("@alias micron = micrometer") + + +def parse_quantity(item: str) -> pint.Quantity: + """Parse strings into ``pint.Quantity``, accounting for infinity. + + Parameters + ---------- + item : str + A quantity string like those used by ``pint``. + + Returns + ------- + pint.Quantity + """ + item = item.strip() + try: + q = ureg.Quantity(item) + except pint.UndefinedUnitError as e: + first, *other = item.split() + m = INF_RE.match(first) + if m is None: + raise e + + val = float("inf") + if m.groupdict()["is_neg"]: + val *= -1 + unit = ureg.Unit(" ".join(other)) + q = ureg.Quantity(val, unit) + + return q + + +def as_unit(unit: Optional[Union[str, pint.Unit]]) -> pint.Unit: + """Convert a string (or None) into a ``pint.Unit`` + + Parameters + ---------- + unit : Optional[Union[str, pint.Unit]] + + Returns + ------- + pint.Unit + If the ``unit`` argument was ``None``, return dimensionless. + """ + if unit is None: + return DIMENSIONLESS + + if isinstance(unit, pint.Unit): + return unit + + return ureg.Unit(unit) + + +def reduce_units(units: Sequence[Optional[Union[str, pint.Unit]]]) -> pint.Unit: + """Reduce a sequence of units or unit-like strings down to a single ``pint.Unit``. + + Dimensionless units are ignored. + + Parameters + ---------- + units : Sequence[Optional[Union[str, pint.Unit]]] + ``None`` is treated as dimensionless. + + Returns + ------- + pint.Unit + Consensus units of the sequence. + + Raises + ------ + ValueError + If more than one non-dimensionless unit is found. + """ + # use np.unique instead of set operations here, + # because setting aliases in the registry affects + # __eq__ (comparisons as used by np.unique) but not + # __hash__ (as used by sets) + unit_set = np.unique([DIMENSIONLESS] + [as_unit(u1) for u1 in units]) + if len(unit_set) == 1: + return DIMENSIONLESS + actuals = list(unit_set) + actuals.remove(DIMENSIONLESS) + if len(actuals) == 1: + return actuals[0] + raise ValueError(f"More than one real unit found: {sorted(unit_set)}")