Skip to content

Commit

Permalink
Add SimpleLookup for non-float cases
Browse files Browse the repository at this point in the history
  • Loading branch information
clbarnes committed Dec 7, 2021
1 parent 54b1da0 commit e381eba
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 15 deletions.
2 changes: 1 addition & 1 deletion navis/nbl/nblast_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self, use_alpha=False, normalized=True, smat='auto',

if limit_dist == "auto":
try:
self.distance_upper_bound = self.score_fn.digitizers[0]._max
self.distance_upper_bound = self.score_fn.axes[0]._max
except AttributeError:
logger.warning("Could not infer distance upper bound from scoring function")
self.distance_upper_bound = None
Expand Down
85 changes: 72 additions & 13 deletions navis/nbl/smat.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from itertools import permutations
import sys
import os
from collections import Counter
from concurrent.futures import ProcessPoolExecutor
from typing import (
Generic,
Hashable,
Iterator,
Mapping,
Expand All @@ -15,6 +17,7 @@
Iterable,
Any,
Tuple,
TypeVar,
Union,
)
import logging
Expand Down Expand Up @@ -443,14 +446,70 @@ def parse_boundary(item: str):
return tuple(float(i) for i in item[1:-1].split(",")), right


class Digitizer:
T = TypeVar("T")


class LookupAxis(ABC, Generic[T]):
    """Abstract interface for converting some data into a linear index.

    Concrete subclasses must implement ``__len__`` (the number of bins)
    and ``__call__`` (the value -> index conversion).
    """

    @abstractmethod
    def __len__(self) -> int:
        """Number of bins represented by this instance."""

    @abstractmethod
    def __call__(self, value: Union[T, Sequence[T]]) -> Union[int, Sequence[int]]:
        """Convert some data into a linear index.

        Parameters
        ----------
        value : Union[T, Sequence[T]]
            Value to convert into an index

        Returns
        -------
        Union[int, Sequence[int]]
            If a scalar was given, return a scalar; otherwise, a numpy array of ints.
        """


class SimpleLookup(LookupAxis[Hashable]):
    """Lookup axis mapping each item of a fixed list to its position."""

    def __init__(self, items: List[Hashable]):
        """Look up in a list of items and return their index.

        Parameters
        ----------
        items : List[Hashable]
            The item's position in the list is the index which will be returned.

        Raises
        ------
        ValueError
            items are non-unique.
        """
        # Duplicate items collapse onto one dict key, so a shorter dict
        # than the input list means the items were not unique.
        self.items = {item: idx for idx, item in enumerate(items)}
        if len(self.items) != len(items):
            raise ValueError("Items are not unique")

    def __len__(self) -> int:
        """Number of items (bins) known to this lookup."""
        return len(self.items)

    def __call__(self, value: Union[Hashable, Sequence[Hashable]]) -> Union[int, Sequence[int]]:
        """Return the index of an item, or an array of indices for a sequence.

        Parameters
        ----------
        value : Union[Hashable, Sequence[Hashable]]
            A single item, or a sequence of items, to look up.

        Returns
        -------
        Union[int, Sequence[int]]
            The index if ``value`` is a single item; otherwise a numpy
            array of ints, one per element of ``value``.

        Raises
        ------
        KeyError
            The item (or an element of the sequence) is not in the lookup.
        """
        if np.isscalar(value):
            return self.items[value]

        # ``np.isscalar`` is False for many valid hashable keys (tuples,
        # None, frozensets, plain enum members, ...). Try such a value as a
        # single key first so these items are not iterated element-wise.
        try:
            return self.items[value]
        except (KeyError, TypeError):
            # KeyError: hashable but not itself an item, so treat as a sequence.
            # TypeError: unhashable (list, ndarray, ...), so treat as a sequence.
            pass

        return np.array([self.items[v] for v in value], int)


class Digitizer(LookupAxis[float]):
def __init__(
self,
boundaries: Sequence[float],
clip: Tuple[bool, bool] = (True, True),
right=False,
):
"""Class converting continuous values into discrete indices given specific bin boundaries.
"""Class converting continuous values into discrete indices.
Parameters
----------
Expand Down Expand Up @@ -637,29 +696,29 @@ def __eq__(self, other: object) -> bool:


class LookupNd:
    """N-dimensional lookup table addressed through one axis object per dimension."""

    def __init__(self, axes: List[LookupAxis], cells: np.ndarray):
        """Store the axes and the table of cells they index into.

        Parameters
        ----------
        axes : List[LookupAxis]
            One axis per dimension of ``cells``; each converts a value
            into an index along that dimension.
        cells : np.ndarray
            Values to look up in the table.

        Raises
        ------
        ValueError
            The lengths of ``axes`` do not match the shape of ``cells``.
        """
        if [len(b) for b in axes] != list(cells.shape):
            raise ValueError("boundaries and cells have inconsistent bin counts")
        self.axes = axes
        self.cells = cells

    def __call__(self, *args):
        """Look up the cell(s) addressed by one argument per axis.

        Each positional argument is converted to an index by the axis at
        the same position, and the resulting indices are used together to
        index ``cells``.

        Raises
        ------
        TypeError
            The number of arguments differs from the number of axes.
        """
        if len(args) != len(self.axes):
            raise TypeError(
                f"Lookup takes {len(self.axes)} arguments but {len(args)} were given"
            )

        idxs = tuple(d(arg) for d, arg in zip(self.axes, args))
        out = self.cells[idxs]
        return out


class Lookup2d(LookupNd):
"""Convenience class inheriting from LookupNd for the common 2D case.
"""Convenience class inheriting from LookupNd for the common 2D float case.
Provides IO with pandas DataFrames.
"""

def __init__(self, digitizer0: Digitizer, digitizer1: Digitizer, cells: np.ndarray):
def __init__(self, axis0: Digitizer, axis1: Digitizer, cells: np.ndarray):
"""2D lookup table for convert NBLAST matches to scores.
Commonly read from a ``pandas.DataFrame``
Expand All @@ -674,7 +733,7 @@ def __init__(self, digitizer0: Digitizer, digitizer1: Digitizer, cells: np.ndarr
cells : np.ndarray
Values to look up in the table.
"""
super().__init__([digitizer0, digitizer1], cells)
super().__init__([axis0, axis1], cells)

def to_dataframe(self) -> pd.DataFrame:
"""Convert the lookup table into a ``pandas.DataFrame``.
Expand All @@ -689,8 +748,8 @@ def to_dataframe(self) -> pd.DataFrame:
"""
return pd.DataFrame(
self.cells,
self.digitizers[0].to_strings(),
self.digitizers[1].to_strings(),
self.axes[0].to_strings(),
self.axes[1].to_strings(),
)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion tests/test_nbl/test_smat.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_lookup2d_roundtrip():
df = lookup.to_dataframe()
lookup2 = Lookup2d.from_dataframe(df)
assert np.allclose(lookup.cells, lookup2.cells)
for b1, b2 in zip(lookup.digitizers, lookup2.digitizers):
for b1, b2 in zip(lookup.axes, lookup2.axes):
assert b1 == b2


Expand Down

0 comments on commit e381eba

Please sign in to comment.