Skip to content

Commit

Permalink
WIP units
Browse files Browse the repository at this point in the history
  • Loading branch information
clbarnes committed Mar 24, 2022
1 parent 039283d commit 27d5e44
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 20 deletions.
5 changes: 2 additions & 3 deletions navis/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import matplotlib as mpl

from .units import ureg # noqa: F401

logger = logging.getLogger('navis')


Expand Down Expand Up @@ -69,9 +71,6 @@ def remove_log_handlers():
# Default color for neurons
default_color = (.95, .65, .04)

# Unit registry
ureg = pint.UnitRegistry()

# Set to true to prevent Viewer from ever showing
headless = os.environ.get('NAVIS_HEADLESS', 'False').lower() == 'true'
if headless:
Expand Down
29 changes: 23 additions & 6 deletions navis/nbl/nblast_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,18 @@
import pandas as pd

from concurrent.futures import ProcessPoolExecutor
from typing import Union, Optional, List
from typing import Union, Optional
from typing_extensions import Literal
import pint

from pint.quantity import Quantity

from navis.nbl.smat import Lookup2d, smat_fcwb

from .. import utils, config
from ..core import NeuronList, Dotprops, make_dotprops
from .base import Blaster, NestedIndices
from ..units import as_unit

__all__ = ['nblast', 'nblast_smart', 'nblast_allbyall', 'sim_to_dist']

Expand Down Expand Up @@ -105,23 +109,36 @@ def __init__(self, use_alpha=False, normalized=True, smat='auto',
else:
self.distance_upper_bound = limit_dist

def append(self, dotprops) -> NestedIndices:
def append(self, dotprops, ignore_units=False) -> NestedIndices:
"""Append dotprops.
Returns the numerical index appended dotprops.
If dotprops is a (possibly nested) sequence of dotprops,
return a (possibly nested) list of indices.
"""
if isinstance(dotprops, Dotprops):
return self._append_dotprops(dotprops)
return self._append_dotprops(dotprops, ignore_units)

try:
return [self.append(n) for n in dotprops]
return [self.append(n, ignore_units) for n in dotprops]
except TypeError: # i.e. not iterable
raise ValueError(f"Expected Dotprops or iterable thereof; got {type(dotprops)}")

def _append_dotprops(self, dotprops: Dotprops) -> int:
next_id = len(self)
def _append_dotprops(self, dotprops: Dotprops, ignore_units) -> int:
if not ignore_units:
# if isinstance(dotprops.units, pint.Quantity):
# if np.allclose(1, dotprops.units):
# units = dotprops.units.units
# else:
# logger.warning(
# "Dotprops coordinates are not unitary (%s). "
# "This might lead to unexpected results.",
# dotprops.units
# )
# elif dotprops:
# units = as_unit(dotprops.units)

# if as_unit(dotprops.units)
self.neurons.append(dotprops)
self.ids.append(dotprops.id)
# Calculate score for self hit
Expand Down
83 changes: 72 additions & 11 deletions navis/nbl/smat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,18 @@
import math
from collections import defaultdict

import pint
import numpy as np
import pandas as pd

from ..config import ureg
from ..core.neurons import Dotprops
from ..units import parse_quantity, reduce_units, as_unit, DIMENSIONLESS

logger = logging.getLogger(__name__)

DEFAULT_SEED = 1991
DIMENSIONLESS = pint.Unit("dimensionless")

epsilon = sys.float_info.epsilon
cpu_count = max(1, (os.cpu_count() or 2) - 1)
Expand Down Expand Up @@ -433,7 +437,23 @@ def is_monotonically_increasing(lst):
return True


def parse_boundary(item: str):
def parse_interval(item: str) -> Tuple[Tuple[float, float], bool, pint.Unit]:
"""Parse a string representing a half-open interval.
Parameters
----------
item : str
An interval formatted like ``[1 nm, inf nm)``
Returns
-------
Tuple[Tuple[float, float], bool, pint.Unit]
(
(lower_bound, upper_bound),
include_right,
units
)
"""
explicit_interval = item[0] + item[-1]
if explicit_interval == "[)":
right = False
Expand All @@ -443,7 +463,18 @@ def parse_boundary(item: str):
raise ValueError(
f"Enclosing characters '{explicit_interval}' do not match a half-open interval"
)
return tuple(float(i) for i in item[1:-1].split(",")), right

out = []
units = []
for s in item[1:-1].split(","):
q = parse_quantity(s)
out.append(q.magnitude)
units.append(q.units)

if len(out) != 2:
raise ValueError(f"Could not parse interval, got {len(out)} values instead of 2")

return tuple(out), right, reduce_units(units)


T = TypeVar("T")
Expand All @@ -452,6 +483,8 @@ def parse_boundary(item: str):
class LookupAxis(ABC, Generic[T]):
"""Class converting some data into a linear index."""

units: pint.Unit = DIMENSIONLESS

@abstractmethod
def __len__(self) -> int:
"""Number of bins represented by this instance."""
Expand All @@ -473,6 +506,13 @@ def __call__(self, value: Union[T, Sequence[T]]) -> Union[int, Sequence[int]]:
"""
pass

def same_units(self, unit: Union[str, pint.Unit]) -> bool:
try:
reduce_units([self.units, as_unit(unit)])
except ValueError:
return False
return True


class SimpleLookup(LookupAxis[Hashable]):
def __init__(self, items: List[Hashable]):
Expand Down Expand Up @@ -508,6 +548,8 @@ def __init__(
boundaries: Sequence[float],
clip: Tuple[bool, bool] = (True, True),
right=False,
*,
units=DIMENSIONLESS,
):
"""Class converting continuous values into discrete indices.
Expand Down Expand Up @@ -546,6 +588,7 @@ def __init__(
raise ValueError("Boundaries are not monotonically increasing")

self.boundaries = np.asarray(boundaries)
self.units = as_unit(units)

def __len__(self):
return len(self.boundaries) - 1
Expand All @@ -570,8 +613,14 @@ def to_strings(self) -> List[str]:
b = self.boundaries.copy()
b[0] = self._min
b[-1] = self._max

if self.units == DIMENSIONLESS:
unit = ""
else:
unit = " " + str(self.units)

return [
f"{lb}{lower},{upper}{rb}"
f"{lb}{lower}{unit},{upper}{unit}{rb}"
for lower, upper in zip(b[:-1], b[1:])
]

Expand All @@ -596,8 +645,10 @@ def from_strings(cls, interval_strs: Sequence[str]):
bounds: List[float] = []
last_upper = None
last_right = None
units = []
for item in interval_strs:
(lower, upper), right = parse_boundary(item)
(lower, upper), right, unit = parse_interval(item)
units.append(unit)
bounds.append(float(lower))

if last_right is not None:
Expand All @@ -613,10 +664,10 @@ def from_strings(cls, interval_strs: Sequence[str]):
last_upper = upper

bounds.append(float(last_upper))
return cls(bounds, right=last_right)
return cls(bounds, right=last_right, units=reduce_units(units))

@classmethod
def from_linear(cls, lower: float, upper: float, nbins: int, right=False):
def from_linear(cls, lower: float, upper: float, nbins: int, right=False, *, units=DIMENSIONLESS):
"""Choose digitizer boundaries spaced linearly between two values.
Input values will be clipped to fit within the given interval.
Expand All @@ -638,10 +689,10 @@ def from_linear(cls, lower: float, upper: float, nbins: int, right=False):
Digitizer
"""
arr = np.linspace(lower, upper, nbins + 1, endpoint=True)
return cls(arr, right=right)
return cls(arr, right=right, units=units)

@classmethod
def from_geom(cls, lowest_upper: float, highest_lower: float, nbins: int, right=False):
def from_geom(cls, lowest_upper: float, highest_lower: float, nbins: int, right=False, *, units=DIMENSIONLESS):
"""Choose digitizer boundaries in a geometric sequence.
Additional bins will be added above and below the given values.
Expand All @@ -664,7 +715,7 @@ def from_geom(cls, lowest_upper: float, highest_lower: float, nbins: int, right=
Digitizer
"""
arr = np.geomspace(lowest_upper, highest_lower, nbins - 1, True)
return cls(arr, clip=(False, False), right=right)
return cls(arr, clip=(False, False), right=right, units=units)

@classmethod
def from_data(cls, data: Sequence[float], nbins: int, right=False):
Expand All @@ -684,8 +735,13 @@ def from_data(cls, data: Sequence[float], nbins: int, right=False):
-------
Digitizer
"""
if isinstance(data, pint.Quantity):
data = data.magnitude
units = data.units
else:
units = DIMENSIONLESS
arr = np.quantile(data, np.linspace(0, 1, nbins + 1, True))
return cls(arr, right=right)
return cls(arr, right=right, units=units)

def __eq__(self, other: object) -> bool:
if not isinstance(other, Digitizer):
Expand All @@ -712,6 +768,10 @@ def __call__(self, *args):
out = self.cells[idxs]
return out

@property
def units(self):
return tuple(ax.units for ax in self.axes)


class Lookup2d(LookupNd):
"""Convenience class inheriting from LookupNd for the common 2D float case.
Expand Down Expand Up @@ -776,7 +836,8 @@ def _smat_fcwb(alpha=False):
fname = ("smat_fcwb.csv", "smat_alpha_fcwb.csv")[alpha]
fpath = smat_path / fname

return Lookup2d.from_dataframe(pd.read_csv(fpath, index_col=0))
lookup = Lookup2d.from_dataframe(pd.read_csv(fpath, index_col=0))
lookup.axes[0].units = ureg.Unit("micrometer")


def smat_fcwb(alpha=False):
Expand Down
97 changes: 97 additions & 0 deletions navis/units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import re
from typing import Optional, Union, Sequence

import numpy as np
import pint

DIMENSIONLESS = pint.Unit("dimensionless")

INF_RE = re.compile(r"(?P<is_neg>-?)inf(inity)?")

ureg = pint.UnitRegistry()
ureg.define("@alias micron = micrometer")


def parse_quantity(item: str) -> pint.Quantity:
"""Parse strings into ``pint.Quantity``, accounting for infinity.
Parameters
----------
item : str
A quantity string like those used by ``pint``.
Returns
-------
pint.Quantity
"""
item = item.strip()
try:
q = ureg.Quantity(item)
except pint.UndefinedUnitError as e:
first, *other = item.split()
m = INF_RE.match(first)
if m is None:
raise e

val = float("inf")
if m.groupdict()["is_neg"]:
val *= -1
unit = ureg.Unit(" ".join(other))
q = ureg.Quantity(val, unit)

return q


def as_unit(unit: Optional[Union[str, pint.Unit]]) -> pint.Unit:
"""Convert a string (or None) into a ``pint.Unit``
Parameters
----------
unit : Optional[Union[str, pint.Unit]]
Returns
-------
pint.Unit
If the ``unit`` argument was ``None``, return dimensionless.
"""
if unit is None:
return DIMENSIONLESS

if isinstance(unit, pint.Unit):
return unit

return ureg.Unit(unit)


def reduce_units(units: Sequence[Optional[Union[str, pint.Unit]]]) -> pint.Unit:
"""Reduce a sequence of units or unit-like strings down to a single ``pint.Unit``.
Dimensionless units are ignored.
Parameters
----------
units : Sequence[Optional[Union[str, pint.Unit]]]
``None`` is treated as dimensionless.
Returns
-------
pint.Unit
Consensus units of the sequence.
Raises
------
ValueError
If more than one non-dimensionless unit is found.
"""
# use np.unique instead of set operations here,
# because setting aliases in the registry affects
# __eq__ (comparisons as used by np.unique) but not
# __hash__ (as used by sets)
unit_set = np.unique([DIMENSIONLESS] + [as_unit(u1) for u1 in units])
if len(unit_set) == 1:
return DIMENSIONLESS
actuals = list(unit_set)
actuals.remove(DIMENSIONLESS)
if len(actuals) == 1:
return actuals[0]
raise ValueError(f"More than one real unit found: {sorted(unit_set)}")

0 comments on commit 27d5e44

Please sign in to comment.