From e381eba5ae9b07b5b477f0194c2cf0923ba470c4 Mon Sep 17 00:00:00 2001 From: Chris Barnes Date: Tue, 7 Dec 2021 11:12:15 +0000 Subject: [PATCH] Add SimpleLookup for non-float cases --- navis/nbl/nblast_funcs.py | 2 +- navis/nbl/smat.py | 85 +++++++++++++++++++++++++++++++------ tests/test_nbl/test_smat.py | 2 +- 3 files changed, 74 insertions(+), 15 deletions(-) diff --git a/navis/nbl/nblast_funcs.py b/navis/nbl/nblast_funcs.py index 9e8139f2..7a16ed03 100644 --- a/navis/nbl/nblast_funcs.py +++ b/navis/nbl/nblast_funcs.py @@ -96,7 +96,7 @@ def __init__(self, use_alpha=False, normalized=True, smat='auto', if limit_dist == "auto": try: - self.distance_upper_bound = self.score_fn.digitizers[0]._max + self.distance_upper_bound = self.score_fn.axes[0]._max except AttributeError: logger.warning("Could not infer distance upper bound from scoring function") self.distance_upper_bound = None diff --git a/navis/nbl/smat.py b/navis/nbl/smat.py index 2c6f07e6..8aa9ee8c 100644 --- a/navis/nbl/smat.py +++ b/navis/nbl/smat.py @@ -1,10 +1,12 @@ from __future__ import annotations +from abc import ABC, abstractmethod from itertools import permutations import sys import os from collections import Counter from concurrent.futures import ProcessPoolExecutor from typing import ( + Generic, Hashable, Iterator, Mapping, @@ -15,6 +17,7 @@ Iterable, Any, Tuple, + TypeVar, Union, ) import logging @@ -443,14 +446,70 @@ def parse_boundary(item: str): return tuple(float(i) for i in item[1:-1].split(",")), right -class Digitizer: +T = TypeVar("T") + + +class LookupAxis(ABC, Generic[T]): + """Class converting some data into a linear index.""" + + @abstractmethod + def __len__(self) -> int: + """Number of bins represented by this instance.""" + pass + + @abstractmethod + def __call__(self, value: Union[T, Sequence[T]]) -> Union[int, Sequence[int]]: + """Convert some data into a linear index. + + Parameters + ---------- + value : Union[T, Sequence[T]] + Value to convert into an index + + Returns + ------- + Union[int, Sequence[int]] + If a scalar was given, return a scalar; otherwise, a numpy array of ints. + """ + pass + + +class SimpleLookup(LookupAxis[Hashable]): + def __init__(self, items: List[Hashable]): + """Look up in a list of items and return their index. + + Parameters + ---------- + items : List[Hashable] + The item's position in the list is the index which will be returned. + + Raises + ------ + ValueError + items are non-unique. + """ + self.items = {item: idx for idx, item in enumerate(items)} + if len(self.items) != len(items): + raise ValueError("Items are not unique") + + def __len__(self) -> int: + return len(self.items) + + def __call__(self, value: Union[Hashable, Sequence[Hashable]]) -> Union[int, Sequence[int]]: + if np.isscalar(value): + return self.items[value] + else: + return np.array([self.items[v] for v in value], int) + + +class Digitizer(LookupAxis[float]): def __init__( self, boundaries: Sequence[float], clip: Tuple[bool, bool] = (True, True), right=False, ): - """Class converting continuous values into discrete indices given specific bin boundaries. + """Class converting continuous values into discrete indices. Parameters ---------- @@ -637,29 +696,29 @@ def __eq__(self, other: object) -> bool: class LookupNd: - def __init__(self, digitizers: List[Digitizer], cells: np.ndarray): - if [len(b) for b in digitizers] != list(cells.shape): + def __init__(self, axes: List[LookupAxis], cells: np.ndarray): + if [len(b) for b in axes] != list(cells.shape): raise ValueError("boundaries and cells have inconsistent bin counts") - self.digitizers = digitizers + self.axes = axes self.cells = cells def __call__(self, *args): - if len(args) != len(self.digitizers): + if len(args) != len(self.axes): raise TypeError( - f"Lookup takes {len(self.digitizers)} arguments but {len(args)} were given" + f"Lookup takes {len(self.axes)} arguments but {len(args)} were given" ) - idxs = tuple(d(arg) for d, arg in zip(self.digitizers, args)) + idxs = tuple(d(arg) for d, arg in zip(self.axes, args)) out = self.cells[idxs] return out class Lookup2d(LookupNd): - """Convenience class inheriting from LookupNd for the common 2D case. + """Convenience class inheriting from LookupNd for the common 2D float case. Provides IO with pandas DataFrames. """ - def __init__(self, digitizer0: Digitizer, digitizer1: Digitizer, cells: np.ndarray): + def __init__(self, axis0: Digitizer, axis1: Digitizer, cells: np.ndarray): """2D lookup table for convert NBLAST matches to scores. Commonly read from a ``pandas.DataFrame`` @@ -674,7 +733,7 @@ def __init__(self, digitizer0: Digitizer, digitizer1: Digitizer, cells: np.ndarr cells : np.ndarray Values to look up in the table. """ - super().__init__([digitizer0, digitizer1], cells) + super().__init__([axis0, axis1], cells) def to_dataframe(self) -> pd.DataFrame: """Convert the lookup table into a ``pandas.DataFrame``. @@ -689,8 +748,8 @@ def to_dataframe(self) -> pd.DataFrame: """ return pd.DataFrame( self.cells, - self.digitizers[0].to_strings(), - self.digitizers[1].to_strings(), + self.axes[0].to_strings(), + self.axes[1].to_strings(), ) @classmethod diff --git a/tests/test_nbl/test_smat.py b/tests/test_nbl/test_smat.py index 76d7f9fb..04d2d4f6 100644 --- a/tests/test_nbl/test_smat.py +++ b/tests/test_nbl/test_smat.py @@ -87,7 +87,7 @@ def test_lookup2d_roundtrip(): df = lookup.to_dataframe() lookup2 = Lookup2d.from_dataframe(df) assert np.allclose(lookup.cells, lookup2.cells) - for b1, b2 in zip(lookup.digitizers, lookup2.digitizers): + for b1, b2 in zip(lookup.axes, lookup2.axes): assert b1 == b2