[REF] Modularize metric calculation #591

Merged · 27 commits · Jun 14, 2021
Changes from 4 commits
3 changes: 3 additions & 0 deletions Makefile
@@ -23,6 +23,9 @@ three-echo:
four-echo:
	@py.test --cov-append --cov-report term-missing --cov=tedana -k test_integration_four_echo tedana/tests/test_integration.py

five-echo:
	@py.test --cov-append --cov-report term-missing --cov=tedana -k test_integration_five_echo tedana/tests/test_integration.py

7 changes: 3 additions & 4 deletions docs/api.rst
@@ -93,11 +93,10 @@ API

.. autosummary::
   :toctree: generated/
-   :template: function.rst
+   :template: module.rst

-   tedana.metrics.dependence_metrics
-   tedana.metrics.kundu_metrics
+   tedana.metrics.collect
+   tedana.metrics.dependence

.. _api_selection_ref:

3 changes: 3 additions & 0 deletions setup.cfg
@@ -12,3 +12,6 @@ exclude=*build/
ignore = E126,E402,W504
per-file-ignores =
    */__init__.py:F401

[tool:pytest]
log_cli = true
11 changes: 10 additions & 1 deletion tedana/decay.py
@@ -2,8 +2,10 @@
Functions to estimate S0 and T2* from multi-echo data.
"""
import logging
-import scipy
import numpy as np
+import scipy
+from scipy import stats

from tedana import utils

LGR = logging.getLogger(__name__)
@@ -299,6 +301,13 @@ def fit_decay(data, tes, mask, adaptive_mask, fittype):
    t2s_full = utils.unmask(t2s_full, mask)
    s0_full = utils.unmask(s0_full, mask)

    # set a hard cap for the T2* map
    # anything that is 10x higher than the 99.5 %ile will be reset to 99.5 %ile
    cap_t2s = stats.scoreatpercentile(t2s_limited.flatten(), 99.5,
                                      interpolation_method='lower')
    LGR.debug('Setting cap on T2* map at {:.5f}'.format(cap_t2s * 10))
    t2s_limited[t2s_limited > cap_t2s * 10] = cap_t2s

    return t2s_limited, s0_limited, t2s_full, s0_full
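For intuition, here is a minimal, self-contained sketch of the same capping rule on synthetic values (the array contents and variable names below are illustrative, not taken from tedana):

import numpy as np
from scipy import stats

# synthetic T2* estimates with a few implausibly large fit values
rng = np.random.default_rng(0)
t2s = rng.uniform(20, 80, size=1000)
t2s[:5] = [5000, 9000, 12000, 20000, 50000]

# 99.5th percentile, rounded down to an observed value
cap_t2s = stats.scoreatpercentile(t2s.flatten(), 99.5, interpolation_method='lower')

# values more than 10x the 99.5th percentile are reset to the percentile itself
t2s[t2s > cap_t2s * 10] = cap_t2s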


17 changes: 11 additions & 6 deletions tedana/decomposition/pca.py
@@ -204,12 +204,17 @@ def tedpca(data_cat, data_oc, combmode, mask, adaptive_mask, t2sG,
    varex_norm = varex / varex.sum()

    # Compute Kappa and Rho for PCA comps
-    # Normalize each component's time series
-    vTmixN = stats.zscore(comp_ts, axis=0)
-    comptable, _, _, _ = metrics.dependence_metrics(
-        data_cat, data_oc, comp_ts, adaptive_mask, tes, ref_img,
-        reindex=False, mmixN=vTmixN, algorithm=None,
-        label='mepca_', out_dir=out_dir, verbose=verbose)
+    required_metrics = [
+        'kappa', 'rho', 'countnoise', 'countsigFT2', 'countsigFS0',
+        'dice_FT2', 'dice_FS0', 'signal-noise_t',
+        'variance explained', 'normalized variance explained',
+        'd_table_score'
+    ]
+    comptable, _ = metrics.collect.generate_metrics(
+        data_cat, data_oc, comp_ts, mask, adaptive_mask,
+        tes, ref_img,
+        metrics=required_metrics, sort_by=None
+    )

    # varex_norm from PCA overrides varex_norm from dependence_metrics,
    # but we retain the original
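For readers new to the reworked API: comptable is expected to be a per-component table with one column per requested metric, and sort_by=None appears to leave components in their PCA order (consistent with sort_df in tedana/metrics/_utils.py returning its input unchanged when by is None). A mocked illustration of how such a table might be inspected downstream; the values and the DataFrame construction are invented for illustration:

import pandas as pd

comptable = pd.DataFrame({
    'kappa': [45.2, 120.7, 30.1],
    'rho': [80.3, 15.4, 60.8],
    'variance explained': [5.1, 22.0, 3.3],
})

# e.g., rank components by kappa without reordering the table itself
print(comptable.sort_values(by='kappa', ascending=False).index.tolist())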
6 changes: 3 additions & 3 deletions tedana/metrics/__init__.py
@@ -1,10 +1,10 @@
# emacs: -*- mode: python-mode; py-indent-offset: 4; tab-width: 4; indent-tabs-mode: nil -*-
# ex: set sts=4 ts=4 sw=4 et:

-from .kundu_fit import (
-    dependence_metrics, kundu_metrics
+from .collect import (
+    generate_metrics
)

__all__ = [
-    'dependence_metrics', 'kundu_metrics'
+    'generate_metrics'
]
183 changes: 183 additions & 0 deletions tedana/metrics/_utils.py
@@ -0,0 +1,183 @@
"""
Misc. utils for metric calculation.
"""
import logging

import numpy as np
from scipy import stats

LGR = logging.getLogger(__name__)


def dependency_resolver(dict_, requested_metrics, base_inputs):
"""
Identify all necessary metrics based on a list of requested metrics and
the metrics each one requires to be calculated, as defined in a dictionary.

Parameters
----------
dict_ : :obj:`dict`
Dictionary containing lists, where each key is a metric name and its
associated value is the list of metrics or inputs required to calculate
it.
requested_metrics : :obj:`list`
Child metrics for which the function will determine parents.
base_inputs : :obj:`list`
A list of inputs to the metric collection function, to differentiate
them from metrics to be calculated.

Returns
-------
required_metrics :obj:`list`
A comprehensive list of all metrics and inputs required to generate all
of the requested inputs.
"""
    not_found = [k for k in requested_metrics if k not in dict_.keys()]
    if not_found:
        raise ValueError('Unknown metric(s): {}'.format(', '.join(not_found)))

    required_metrics = requested_metrics
    escape_counter = 0
    while True:
        required_metrics_new = required_metrics[:]
        for k in required_metrics:
            if k in dict_.keys():
                new_metrics = dict_[k]
            elif k not in base_inputs:
                print("Warning: {} not found".format(k))
                new_metrics = []
            else:
                # base inputs are already available and add no new dependencies
                new_metrics = []
            required_metrics_new += new_metrics
        if set(required_metrics) == set(required_metrics_new):
            # There are no more parent metrics to calculate
Collaborator (review comment): How about a LGR.debug here?
            break
        else:
            required_metrics = required_metrics_new
        escape_counter += 1
        if escape_counter >= 10:
            LGR.warning('dependency_resolver in infinite loop. Escaping early.')
            break
    return required_metrics
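A toy usage sketch of dependency_resolver. The dependency dictionary, metric names, and base inputs below are made up for illustration and are not tedana's actual metric graph:

from tedana.metrics._utils import dependency_resolver

# each metric maps to the metrics or inputs needed to compute it
metric_dependencies = {
    'kappa': ['map FT2', 'varex'],
    'rho': ['map FS0', 'varex'],
    'varex': ['mixing', 'data_optcom'],
    'map FT2': ['mixing', 'data_cat'],
    'map FS0': ['mixing', 'data_cat'],
}
base_inputs = ['mixing', 'data_cat', 'data_optcom']

# expands ['kappa', 'rho'] into every intermediate metric and base input they need
needed = dependency_resolver(metric_dependencies, ['kappa', 'rho'], base_inputs)
print(sorted(set(needed)))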


def determine_signs(weights, axis=0):
"""
Determine component-wise optimal signs using voxel-wise parameter estimates.

Parameters
----------
weights : (S x C) array_like
Parameter estimates for optimally combined data against the mixing
matrix.

Returns
-------
signs : (C) array_like
Array of 1 and -1 values corresponding to the appropriate flips for the
mixing matrix's component time series.
"""
# compute skews to determine signs based on unnormalized weights,
signs = stats.skew(weights, axis=axis)
signs /= np.abs(signs)
return signs


def flip_components(*args, signs):
"""
Flip an arbitrary set of input arrays based on a set of signs.

Parameters
----------
*args : array_like
Any number of arrays with one dimension the same length as signs.
If multiple dimensions share the same size as signs, behavior of this
function will be unpredictable.
signs : array_like of :obj:`int`
Array of +/- 1 by which to flip the values in each argument.

Returns
-------
*args : array_like
Input arrays after sign flipping.
"""
assert signs.ndim == 1, 'Argument "signs" must be one-dimensional.'
for arg in args:
assert len(signs) in arg.shape, \
('Size of argument "signs" must match size of one dimension in '
'each of the input arguments.')
assert sum(x == len(signs) for x in arg.shape) == 1, \
('Only one dimension of each input argument can match the length '
'of argument "signs".')
# correct mixing & weights signs based on spatial distribution tails
return [arg * signs for arg in args]
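A short sketch of how determine_signs and flip_components are meant to compose, on random arrays (shapes and names are illustrative):

import numpy as np
from tedana.metrics._utils import determine_signs, flip_components

rng = np.random.default_rng(42)
weights = rng.normal(size=(500, 4))   # S x C voxel-wise parameter estimates
mixing = rng.normal(size=(100, 4))    # T x C mixing matrix

# one sign per component, from the skew of its voxel-wise weights
signs = determine_signs(weights, axis=0)

# flip both arrays consistently; each has exactly one axis of length C
weights, mixing = flip_components(weights, mixing, signs=signs)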


def sort_df(df, by='kappa', ascending=False):
"""
Sort DataFrame and get index.

Parameters
----------
df : :obj:`pandas.DataFrame`
DataFrame to sort.
by : :obj:`str` or None, optional
Column by which to sort the DataFrame. Default is 'kappa'.
ascending : :obj:`bool`, optional
Whether to sort the DataFrame in ascending (True) or descending (False)
order. Default is False.

Returns
-------
df : :obj:`pandas.DataFrame`
DataFrame after sorting, with index resetted.
argsort : array_like
Sorting index.
"""
if by is None:
return df, df.index.values

# Order of kwargs is preserved at 3.6+
argsort = df[by].argsort()
if not ascending:
argsort = argsort[::-1]
df = df.loc[argsort].reset_index(drop=True)
return df, argsort


def apply_sort(*args, sort_idx, axis=0):
"""
Apply a sorting index to an arbitrary set of arrays.
"""
for arg in args:
assert arg.shape[axis] == len(sort_idx)
return [np.take(arg, sort_idx, axis=axis) for arg in args]
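Similarly, a sketch of sort_df and apply_sort used together, with a mocked component table (values are illustrative):

import numpy as np
import pandas as pd
from tedana.metrics._utils import sort_df, apply_sort

comptable = pd.DataFrame({'kappa': [30.0, 120.0, 60.0],
                          'rho': [80.0, 20.0, 50.0]})
mixing = np.random.default_rng(0).normal(size=(100, 3))  # T x C mixing matrix

# sort the table by kappa (descending) and capture the sort order
comptable, argsort = sort_df(comptable, by='kappa', ascending=False)

# reorder the mixing matrix columns to match the sorted table
mixing = apply_sort(mixing, sort_idx=argsort, axis=1)[0]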


def check_mask(data, mask):
"""
Check that no zero-variance voxels remain in masked data.

Parameters
----------
data : (S [x E] x T) array_like
Data to be masked and evaluated.
mask : (S) array_like
Boolean mask.

Raises
------
ValueError
"""
assert data.ndim <= 3
assert mask.shape[0] == data.shape[0]
masked_data = data[mask, ...]
dims_to_check = list(range(1, data.ndim))
for dim in dims_to_check:
# ignore singleton dimensions
if masked_data.shape[dim] == 1:
continue

masked_data_std = masked_data.std(axis=dim)
zero_idx = np.where(masked_data_std == 0)
n_bad_voxels = len(zero_idx[0])
if n_bad_voxels > 0:
raise ValueError('{0} voxels in masked data have zero variance. '
'Mask is too liberal.'.format(n_bad_voxels))
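Finally, a quick sketch of the failure mode check_mask guards against, on synthetic two-dimensional (S x T) data:

import numpy as np
from tedana.metrics._utils import check_mask

rng = np.random.default_rng(1)
data = rng.normal(size=(100, 50))   # S x T
mask = np.ones(100, dtype=bool)

check_mask(data, mask)              # passes: every masked voxel varies over time

data[10, :] = 0.0                   # introduce a flat, zero-variance voxel
try:
    check_mask(data, mask)
except ValueError as err:
    print(err)  # '1 voxels in masked data have zero variance. Mask is too liberal.'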