Skip to content

Commit

Permalink
Refactor smat-checking
Browse files Browse the repository at this point in the history
  • Loading branch information
clbarnes committed Apr 15, 2021
1 parent 9d8f577 commit 978a99d
Showing 1 changed file with 41 additions and 43 deletions.
84 changes: 41 additions & 43 deletions navis/nbl/nblast_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,57 +115,21 @@ def parse_interval(self, s):
""".strip()


def is_score_function(fn: Callable[[float, float], float]):
f"""Ensure that the score function is valid for NBLAST.
{score_func_description}
"""
test_arr = np.array([0.5] * 3)
out = fn(test_arr, test_arr)

return isinstance(fn(0.5, 0.5), float) and out.shape == test_arr.shape and out.dtype == test_arr.dtype
class NBlaster:
f"""Implements version 2 of the NBLAST algorithm.
Please note that some properties are computed on initialization and
changing parameters (e.g. ``use_alpha``) at a later stage will mess things
up!
def parse_score_fn(smat, use_alpha):
f"""
The highly flexible ``smat`` argument converts raw point match parameters
nto a single score representing how good that match is.
into a single score representing how good that match is.
Most simply, it is an NBLAST score function.
{score_func_description}
If a ``pandas.DataFrame``, converts this into a ``navis.Lookup2d`` and uses as above.
If path-like, converts this into a dataframe and uses as above.
If ``None``, uses ``operator.mul``.
If ``'auto'`` (default), uses score matrices from FCWB (like R's nat.nblast).
"""
if smat is None:
smat = operator.mul
elif smat == 'auto':
if use_alpha:
smat = smat_path / 'smat_alpha_fcwb.csv'
else:
smat = smat_path / 'smat_fcwb.csv'

if isinstance(smat, (str, os.PathLike)):
smat = pd.read_csv(smat, index_col=0)

if isinstance(smat, pd.DataFrame):
smat = Lookup2d.from_dataframe(smat)

if not callable(smat):
raise ValueError("smat should be a callable, a path, a pandas.DataFrame, or 'auto'")

if not is_score_function(smat):
raise ValueError("smat is not a valid NBLAST score function, see documentation")


class NBlaster:
f"""Implements version 2 of the NBLAST algorithm.
Please note that some properties are computed on initialization and
changing parameters (e.g. ``use_alpha``) at a later stage will mess things
up!
{parse_score_fn.__doc__}
Parameters
----------
Expand All @@ -189,11 +153,45 @@ def __init__(self, use_alpha=False, normalized=True, smat='auto', progress=True)
self.normalized = normalized
self.progress = progress

self.score_fn = parse_score_fn(smat, use_alpha)
self.score_fn = self._parse_score_fn(smat)

self.self_hits = []
self.dotprops = []

def _parse_score_fn(self, smat):
if smat is None:
smat = operator.mul
elif smat == 'auto':
if self.use_alpha:
smat = smat_path / 'smat_alpha_fcwb.csv'
else:
smat = smat_path / 'smat_fcwb.csv'

if isinstance(smat, (str, os.PathLike)):
smat = pd.read_csv(smat, index_col=0)

if isinstance(smat, pd.DataFrame):
smat = Lookup2d.from_dataframe(smat)

if not callable(smat):
raise ValueError("smat should be a callable, a path, a pandas.DataFrame, or 'auto'")

if not isinstance(smat(0.5, 0.5), float):
raise ValueError("smat does not take 2 floats and return a float")

test_arr = np.array([0.5] * 3)
try:
out = smat(test_arr, test_arr)
except Exception as e:
raise ValueError(f"Failed to use smat with numpy arrays: {e}")

if out.shape != test_arr.shape:
raise ValueError(
f"smat produced inconsistent shape: input {test_arr.shape}; output {out.shape}"
)

return smat

def append(self, dotprops):
"""Append dotprops."""
if isinstance(dotprops, (NeuronList, list)):
Expand Down

0 comments on commit 978a99d

Please sign in to comment.