Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REF] Modularize metric calculation #436

Closed
wants to merge 15 commits into from
7 changes: 7 additions & 0 deletions tedana/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,13 @@ def fit_decay(data, tes, mask, masksum, fittype):
t2s_full[masksum == 1] = t2ss[masksum == 1, 0]
s0_full[masksum == 1] = s0vs[masksum == 1, 0]

# set a hard cap for the T2* map
# anything that is 10x higher than the 99.5 %ile will be reset to 99.5 %ile
cap_t2s = stats.scoreatpercentile(t2s_limited.flatten(), 99.5,
interpolation_method='lower')
LGR.debug('Setting cap on T2* map at {:.5f}'.format(cap_t2s * 10))
t2s_limited[t2s_limited > cap_t2s * 10] = cap_t2s

return t2s_limited, s0_limited, t2ss, s0vs, t2s_full, s0_full


Expand Down
5 changes: 4 additions & 1 deletion tedana/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
# ex: set sts=4 ts=4 sw=4 et:

from .kundu_fit import (
dependence_metrics, kundu_metrics, get_coeffs, computefeats2
kundu_metrics, get_coeffs, computefeats2
)
from .dependence import (
dependence_metrics
)

__all__ = [
Expand Down
55 changes: 55 additions & 0 deletions tedana/metrics/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
Misc. utils for metric calculation.
"""
import numpy as np
from scipy import stats

from tedana.stats import computefeats2


def determine_signs(weights, axis=0):
"""
Determine component-wise optimal signs using voxel-wise parameter estimates.

Parameters
----------
weights : (S x C) array_like
Parameter estimates for optimally combined data against the mixing
matrix.

Returns
-------
signs : (C) array_like
Array of 1 and -1 values corresponding to the appropriate flips for the
mixing matrix's component time series.
"""
# compute skews to determine signs based on unnormalized weights,
signs = stats.skew(weights, axis=axis)
signs /= np.abs(signs)
return signs


def flip_components(*args, signs):
# correct mixing & weights signs based on spatial distribution tails
return [arg * signs for arg in args]


def sort_df(df, by='kappa', ascending=False):
"""
Sort DataFrame and get index.
"""
# Order of kwargs is preserved at 3.6+
argsort = df[by].argsort()
if not ascending:
argsort = argsort[::-1]
df = df.loc[argsort].reset_index(drop=True)
return df, argsort


def apply_sort(*args, sort_idx, axis=0):
"""
Apply a sorting index.
"""
for arg in args:
assert arg.shape[axis] == len(sort_idx)
return [np.take(arg, sort_idx, axis=axis) for arg in args]
140 changes: 140 additions & 0 deletions tedana/metrics/collect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""
Collect metrics.
"""
import logging

import numpy as np
import pandas as pd

from .dependence import *
tsalo marked this conversation as resolved.
Show resolved Hide resolved
from ._utils import determine_signs, flip_components, sort_df, apply_sort


LGR = logging.getLogger(__name__)
RepLGR = logging.getLogger('REPORT')
RefLGR = logging.getLogger('REFERENCES')


def generate_metrics(comptable, data_cat, data_optcom, mixing, mask, tes, ref_img, mixing_z=None,
metrics=['kappa', 'rho'], sort_by='kappa', ascending=False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's generally best to not have a mutable object as a default parameter. Can we set metrics to None and then do a check in the function?

"""
Fit TE-dependence and -independence models to components.

Parameters
----------
data_cat : (S x E x T) array_like
Input data, where `S` is samples, `E` is echos, and `T` is time
data_optcom : (S x T) array_like
Optimally combined data
mixing : (T x C) array_like
Mixing matrix for converting input data to component space, where `C`
is components and `T` is the same as in `data_cat`
mask : img_like
Mask
tes : list
List of echo times associated with `data_cat`, in milliseconds
ref_img : str or img_like
Reference image to dictate how outputs are saved to disk
reindex : bool, optional
Whether to sort components in descending order by Kappa. Default: False
mixing_z : (T x C) array_like, optional
Z-scored mixing matrix. Default: None
algorithm : {'kundu_v2', 'kundu_v3', None}, optional
Decision tree to be applied to metrics. Determines which maps will be
generated and stored in seldict. Default: None
label : :obj:`str` or None, optional
Prefix to apply to generated files. Default is None.
out_dir : :obj:`str`, optional
Output directory for generated files. Default is current working
directory.
verbose : :obj:`bool`, optional
Whether or not to generate additional files. Default is False.

Returns
-------
comptable : (C x X) :obj:`pandas.DataFrame`
Component metric table. One row for each component, with a column for
each metric. The index is the component number.
seldict : :obj:`dict` or None
Dictionary containing component-specific metric maps to be used for
component selection. If `algorithm` is None, then seldict will be None as
well.
betas : :obj:`numpy.ndarray`
mmix_new : :obj:`numpy.ndarray`
"""
RepLGR.info('The following metrics were calculated: {}.'.format(', '.join(metrics)))

if not (data_cat.shape[0] == data_optcom.shape[0] == mask.sum()):
raise ValueError('First dimensions (number of samples) of data_cat ({0}), '
'data_optcom ({1}), and mask ({2}) do not '
'match'.format(data_cat.shape[0], data_optcom.shape[0],
mask.shape[0]))
elif data_cat.shape[1] != len(tes):
raise ValueError('Second dimension of data_cat ({0}) does not match '
'number of echoes provided (tes; '
'{1})'.format(data_cat.shape[1], len(tes)))
elif not (data_cat.shape[2] == data_optcom.shape[1] == mixing.shape[0]):
raise ValueError('Number of volumes in data_cat ({0}), '
'data_optcom ({1}), and mixing ({2}) do not '
'match.'.format(data_cat.shape[2], data_optcom.shape[1], mixing.shape[0]))

mixing = mixing.copy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this copy is unnecessary, here, since the next time it's used is providing mixing to flip_components () which will trigger a copy.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we should add a check for if mixing_z is None and assign mixing to mixing_z.

comptable = pd.DataFrame(index=np.arange(n_components, dtype=int))

# Metric maps
weights = calculate_weights(data_optcom, mixing_z)
signs = determine_signs(weights, axis=0)
weights, mixing = flip_components(weights, mixing, signs=signs)
optcom_betas = calculate_betas(data_optcom, mixing)
PSC = calculate_psc(data_optcom, optcom_betas)

# compute betas and means over TEs for TE-dependence analysis
Z_maps = calculate_z_maps(weights)
F_T2_maps, F_S0_maps = calculate_f_maps(mixing, data_cat, tes, Z_maps)

(Z_clmaps, F_T2_clmaps, F_S0_clmaps,
Br_T2_clmaps, Br_S0_clmaps) = spatial_cluster(
F_T2_maps, F_S0_maps, Z_maps, optcom_betas, mask, n_echos)

# Dependence metrics
if any([v in metrics for v in ['kappa', 'rho']]):
comptable = calculate_dependence_metrics(comptable, F_T2_maps, F_S0_maps, Z_maps)
tsalo marked this conversation as resolved.
Show resolved Hide resolved

# Generic metrics
if 'variance explained' in metrics:
comptable['variance explained'] = calculate_varex(optcom_betas)

if 'normalized variance explained' in metrics:
comptable['normalized variance explained'] = calculate_varex_norm(weights)

# Spatial metrics
if 'dice_FT2' in metrics:
comptable['dice_FT2'] = compute_dice(Br_T2_clmaps, F_T2_clmaps)

if 'dice_FS0' in metrics:
comptable['dice_FS0'] = compute_dice(Br_S0_clmaps, F_S0_clmaps)

if any([v in metrics for v in ['signal-noise_t', 'signal-noise_p']]):
(comptable['signal-noise_t'],
comptable['signal-noise_p']) = compute_signal_minus_noise_t(
Z_maps, Z_clmaps, F_T2_maps)

if 'countnoise' in metrics:
comptable['countnoise'] = compute_countnoise(Z_maps, Z_clmaps)

if 'countsigFT2' in metrics:
comptable['countsigFT2'] = compute_countsigFT2(F_T2_clmaps)

if 'countsigFS0' in metrics:
comptable['countsigFS0'] = compute_countsigFS0(F_S0_clmaps)

if 'd_table_score' in metrics:
comptable['d_table_score'] = generate_decision_table_score(
comptable['kappa'], comptable['dice_FT2'],
comptable['signal_minus_noise_t'], comptable['countnoise'],
comptable['countsigFT2'])

# TODO: move sorting out of this function and only return comptable
comptable, sort_idx = sort_df(comptable, by='kappa', ascending=ascending)
mixing = apply_sort(mixing, sort_idx=sort_idx, axis=1)
return comptable, mixing
Loading