MRG, BUG: Fix biased cov estimation (mne-tools#7369)
* BUG: Fix biased cov estimation

* FIX: Minor fix

* FIX: Link

* FIX: Dont require sklearn

* FIX: sklearn check

* MAINT: Simplify check

* FIX: Flake
larsoner authored and AdoNunes committed Apr 6, 2020
1 parent 9f88805 commit 71abd8f
Showing 4 changed files with 69 additions and 45 deletions.
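The core of the change is moving the empirical covariance from the biased 1/N normalization to the unbiased 1/(N - 1) one, and reporting nfree accordingly. As a minimal NumPy-only sketch of the correction factor involved (random data, not part of the commit):

import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(5, 1000)  # 5 channels, 1000 samples

cov_biased = np.cov(X, ddof=0)    # divides by N
cov_unbiased = np.cov(X, ddof=1)  # divides by N - 1 (np.cov's default)

# Rescaling by N / (N - 1) turns the biased estimate into the unbiased one;
# this is the factor applied in _compute_covariance_auto in the diff below.
n = X.shape[1]
np.testing.assert_allclose(cov_biased * n / (n - 1), cov_unbiased)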
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
@@ -145,6 +145,8 @@ Bug

- Fix bug in ``method='eLORETA'`` for :func:`mne.minimum_norm.apply_inverse` (and variants) to allow restricting source estimation to a label by `Luke Bloy`_

- Fix bug in :func:`mne.compute_covariance` and :func:`mne.compute_raw_covariance` where biased normalization (based on degrees of freedom) was used and ``cov.nfree`` was not set properly by `Eric Larson`_

- Fix :func:`mne.VectorSourceEstimate.normal` to account for cortical patch statistics using ``use_cps=True`` by `Eric Larson`_

- Fix ``pick_ori='normal'`` for :func:`mne.minimum_norm.apply_inverse` when the inverse was computed with ``loose=1.`` and the forward solution was not in surface orientation, by `Eric Larson`_
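As a usage-level illustration of the changelog entry above, this is roughly what the fix means for callers (a sketch, assuming a raw FIF file is available; the path is only a placeholder):

import mne

raw = mne.io.read_raw_fif('my_raw.fif', preload=True)  # placeholder path
cov = mne.compute_raw_covariance(raw, tstep=None)

# With this fix, cov.nfree is the true degrees of freedom (n_samples - 1)
# and cov.data closely matches np.cov(..., ddof=1) on the same channels,
# which the updated tests below check explicitly.
print(cov.nfree)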
19 changes: 10 additions & 9 deletions mne/cov.py
@@ -39,7 +39,8 @@
_check_option, eigh)
from . import viz

from .fixes import BaseEstimator, EmpiricalCovariance, _logdet
from .fixes import (BaseEstimator, EmpiricalCovariance, _logdet,
empirical_covariance, log_likelihood)


def _check_covs_algebra(cov1, cov2):
@@ -463,9 +464,10 @@ def compute_raw_covariance(raw, tmin=0, tmax=None, tstep=0.2, reject=None,
(instead of across epochs) for each channel.
"""
tmin = 0. if tmin is None else float(tmin)
tmax = raw.times[-1] if tmax is None else float(tmax)
dt = 1. / raw.info['sfreq']
tmax = raw.times[-1] + dt if tmax is None else float(tmax)
tstep = tmax - tmin if tstep is None else float(tstep)
tstep_m1 = tstep - 1. / raw.info['sfreq'] # inclusive!
tstep_m1 = tstep - dt # inclusive!
events = make_fixed_length_events(raw, 1, tmin, tmax, tstep)
logger.info('Using up to %s segment%s' % (len(events), _pl(events)))

@@ -504,7 +506,7 @@ def compute_raw_covariance(raw, tmin=0, tmax=None, tstep=0.2, reject=None,
ch_names = [raw.info['ch_names'][k] for k in picks]
bads = [b for b in raw.info['bads'] if b in ch_names]
return Covariance(data, ch_names, bads, raw.info['projs'],
nfree=n_samples)
nfree=n_samples - 1)
del picks, pick_mask

# This makes it equivalent to what we used to do (and do above for
@@ -878,7 +880,7 @@ def _unpack_epochs(epochs):
if keep_sample_mean is False:
cov = cov_data['empirical']['data']
# undo scaling
cov *= n_samples_tot
cov *= (n_samples_tot - 1)
# ... apply pre-computed class-wise normalization
for mean_cov in data_mean:
cov -= mean_cov
@@ -887,7 +889,7 @@ def _unpack_epochs(epochs):
covs = list()
for this_method, data in cov_data.items():
cov = Covariance(data.pop('data'), ch_names, info['bads'], projs,
nfree=n_samples_tot)
nfree=n_samples_tot - 1)

# add extra info
cov.update(method=this_method, **data)
@@ -1073,6 +1075,8 @@ def _compute_covariance_auto(data, method, info, method_params, cv,
loglik = None
# project back
cov = np.dot(eigvec.T, np.dot(cov, eigvec))
# undo bias
cov *= data.shape[0] / (data.shape[0] - 1)
# undo scaling
_undo_scaling_cov(cov, picks_list, scalings)
method_ = method[ei]
@@ -1191,7 +1195,6 @@ def __init__(self, info, grad=0.1, mag=0.1, eeg=0.1, seeg=0.1, ecog=0.1,

def fit(self, X):
"""Fit covariance model with classical diagonal regularization."""
from sklearn.covariance import EmpiricalCovariance
self.estimator_ = EmpiricalCovariance(
store_precision=self.store_precision,
assume_centered=self.assume_centered)
@@ -1232,7 +1235,6 @@ def __init__(self, store_precision, assume_centered,
def fit(self, X):
"""Fit covariance model with oracle shrinkage regularization."""
from sklearn.covariance import shrunk_covariance
from sklearn.covariance import EmpiricalCovariance
self.estimator_ = EmpiricalCovariance(
store_precision=self.store_precision,
assume_centered=self.assume_centered)
@@ -1277,7 +1279,6 @@ def fit(self, X):

def score(self, X_test, y=None):
"""Delegate to modified EmpiricalCovariance instance."""
from sklearn.covariance import empirical_covariance, log_likelihood
# compute empirical covariance of the test set
test_cov = empirical_covariance(X_test - self.estimator_.location_,
assume_centered=True)
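For context on the cov *= data.shape[0] / (data.shape[0] - 1) line added in _compute_covariance_auto above: scikit-learn's covariance estimators normalize by the number of samples, so the projected-back result is rescaled to the unbiased convention used elsewhere in MNE. A standalone sketch of that step (random data; sklearn assumed to be installed):

import numpy as np
from sklearn.covariance import EmpiricalCovariance

rng = np.random.RandomState(42)
data = rng.randn(1000, 8)  # samples x channels, as inside MNE

est = EmpiricalCovariance(assume_centered=True).fit(data)
cov = est.covariance_.copy()                # normalized by N (biased)
cov *= data.shape[0] / (data.shape[0] - 1)  # undo the bias, as in the diff

# now equal (up to rounding) to the unbiased estimator on centered data
np.testing.assert_allclose(cov, data.T @ data / (data.shape[0] - 1))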
78 changes: 53 additions & 25 deletions mne/tests/test_cov.py
@@ -29,7 +29,7 @@
from mne.io.pick import _DATA_CH_TYPES_SPLIT
from mne.preprocessing import maxwell_filter
from mne.rank import _compute_rank_int
from mne.utils import (requires_version, run_tests_if_main,
from mne.utils import (requires_sklearn, run_tests_if_main,
catch_logging, assert_snr)

base_dir = op.join(op.dirname(__file__), '..', 'io', 'tests', 'data')
@@ -235,27 +235,48 @@ def test_io_cov(tmpdir):
read_cov(cov_badname)


@pytest.mark.parametrize('method', (None, ['empirical']))
@pytest.mark.parametrize('method', (None, 'empirical', 'shrunk'))
def test_cov_estimation_on_raw(method, tmpdir):
"""Test estimation from raw (typically empty room)."""
if method == 'shrunk':
try:
import sklearn # noqa: F401
except Exception as exp:
pytest.skip('sklearn is required, got %s' % (exp,))
raw = read_raw_fif(raw_fname, preload=True)
cov_mne = read_cov(erm_cov_fname)
method_params = dict(shrunk=dict(shrinkage=[0]))

# The pure-string uses the more efficient numpy-based method, while
# the list gets triaged to compute_covariance (should be equivalent
# but use more memory)
with pytest.warns(None): # can warn about EEG ref
cov = compute_raw_covariance(raw, tstep=None, method=method,
rank='full')
cov = compute_raw_covariance(
raw, tstep=None, method=method, rank='full',
method_params=method_params)
assert_equal(cov.ch_names, cov_mne.ch_names)
assert_equal(cov.nfree, cov_mne.nfree)
assert_snr(cov.data, cov_mne.data, 1e4)
assert_snr(cov.data, cov_mne.data, 1e6)

# test equivalence with np.cov
cov_np = np.cov(raw.copy().pick_channels(cov['names']).get_data(), ddof=1)
if method != 'shrunk': # can check all
off_diag = np.triu_indices(cov_np.shape[0])
else:
# We explicitly zero out off-diag entries between channel types,
# so let's just check MEG off-diag entries
off_diag = np.triu_indices(len(pick_types(raw.info, exclude=())))
for other in (cov_mne, cov):
assert_allclose(np.diag(cov_np), np.diag(other.data), rtol=5e-6)
assert_allclose(cov_np[off_diag], other.data[off_diag], rtol=4e-3)
assert_snr(cov.data, other.data, 1e6)

# tstep=0.2 (default)
with pytest.warns(None): # can warn about EEG ref
cov = compute_raw_covariance(raw, method=method, rank='full')
assert_equal(cov.nfree, cov_mne.nfree - 119) # cutoff some samples
assert_snr(cov.data, cov_mne.data, 1e2)
cov = compute_raw_covariance(raw, method=method, rank='full',
method_params=method_params)
assert_equal(cov.nfree, cov_mne.nfree - 120) # cutoff some samples
assert_snr(cov.data, cov_mne.data, 170)

# test IO when computation done in Python
cov.save(tmpdir.join('test-cov.fif')) # test saving
@@ -268,27 +289,30 @@ def test_cov_estimation_on_raw(method, tmpdir):
raw_pick = raw.copy().pick_channels(raw.ch_names[:5])
raw_pick.info.normalize_proj()
cov = compute_raw_covariance(raw_pick, tstep=None, method=method,
rank='full')
rank='full', method_params=method_params)
assert cov_mne.ch_names[:5] == cov.ch_names
assert_snr(cov.data, cov_mne.data[:5, :5], 1e4)
cov = compute_raw_covariance(raw_pick, method=method, rank='full')
assert_snr(cov.data, cov_mne.data[:5, :5], 5e6)
cov = compute_raw_covariance(raw_pick, method=method, rank='full',
method_params=method_params)
assert_snr(cov.data, cov_mne.data[:5, :5], 90) # cutoff samps
# make sure we get a warning with too short a segment
raw_2 = read_raw_fif(raw_fname).crop(0, 1)
with pytest.warns(RuntimeWarning, match='Too few samples'):
cov = compute_raw_covariance(raw_2, method=method)
cov = compute_raw_covariance(raw_2, method=method,
method_params=method_params)
# no epochs found due to rejection
pytest.raises(ValueError, compute_raw_covariance, raw, tstep=None,
method='empirical', reject=dict(eog=200e-6))
# but this should work
cov = compute_raw_covariance(raw.copy().crop(0, 10.),
tstep=None, method=method,
reject=dict(eog=1000e-6),
verbose='error')
with pytest.warns(None): # sklearn
cov = compute_raw_covariance(
raw.copy().crop(0, 10.), tstep=None, method=method,
reject=dict(eog=1000e-6), method_params=method_params,
verbose='error')


@pytest.mark.slowtest
@requires_version('sklearn', '0.15')
@requires_sklearn
def test_cov_estimation_on_raw_reg():
"""Test estimation from raw with regularization."""
raw = read_raw_fif(raw_fname, preload=True)
@@ -328,7 +352,10 @@ def test_cov_estimation_with_triggers(rank, tmpdir):
reject=reject, preload=True)

cov = compute_covariance(epochs, keep_sample_mean=True)
_assert_cov(cov, read_cov(cov_km_fname))
cov_km = read_cov(cov_km_fname)
# adjust for nfree bug
cov_km['nfree'] -= 1
_assert_cov(cov, cov_km)

# Test with tmin and tmax (different but not too much)
cov_tmin_tmax = compute_covariance(epochs, tmin=-0.19, tmax=-0.01)
@@ -347,6 +374,7 @@

# cov with keep_sample_mean=False using a list of epochs
cov = compute_covariance(epochs, keep_sample_mean=False)
assert cov_km.nfree == cov.nfree
_assert_cov(cov, read_cov(cov_fname), nfree=False)

method_params = {'empirical': {'assume_centered': False}}
@@ -450,7 +478,7 @@ def test_regularized_covariance():
assert_allclose(data, evoked.data, atol=1e-20)


@requires_version('sklearn', '0.15')
@requires_sklearn
def test_auto_low_rank():
"""Test probabilistic low rank estimators."""
n_samples, n_features, rank = 400, 10, 5
@@ -492,7 +520,7 @@ def get_data(n_samples, n_features, rank, sigma):

@pytest.mark.slowtest
@pytest.mark.parametrize('rank', ('full', None, 'info'))
@requires_version('sklearn', '0.15')
@requires_sklearn
def test_compute_covariance_auto_reg(rank):
"""Test automated regularization."""
raw = read_raw_fif(raw_fname, preload=True)
@@ -530,8 +558,8 @@ def test_compute_covariance_auto_reg(rank):
cov_b['data'][diag_mask])

# but the rest is the same
assert_array_equal(cov_a['data'][off_diag_mask],
cov_b['data'][off_diag_mask])
assert_allclose(cov_a['data'][off_diag_mask],
cov_b['data'][off_diag_mask], rtol=1e-12)

else:
# and here we have shrinkage everywhere.
@@ -604,7 +632,7 @@ def raw_epochs_events():
return (raw, epochs, events)


@requires_version('sklearn', '0.15')
@requires_sklearn
@pytest.mark.parametrize('rank', (None, 'full', 'info'))
def test_low_rank_methods(rank, raw_epochs_events):
"""Test low-rank covariance matrix estimation."""
@@ -639,7 +667,7 @@ def test_low_rank_methods(rank, raw_epochs_events):
(rank, method)


@requires_version('sklearn', '0.15')
@requires_sklearn
def test_low_rank_cov(raw_epochs_events):
"""Test additional properties of low rank computations."""
raw, epochs, events = raw_epochs_events
@@ -713,7 +741,7 @@ def test_low_rank_cov(raw_epochs_events):


@testing.requires_testing_data
@requires_version('sklearn', '0.15')
@requires_sklearn
def test_cov_ctf():
"""Test basic cov computation on ctf data with/without compensation."""
raw = read_raw_ctf(ctf_fname).crop(0., 2.).load_data()
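A note on the method_params=dict(shrunk=dict(shrinkage=[0])) used throughout the test changes above: with zero shrinkage the shrunk estimator should reduce to the empirical covariance, which is why both methods can be compared against the same np.cov reference. A small sklearn-only sketch of that degenerate case (assuming sklearn is installed):

import numpy as np
from sklearn.covariance import empirical_covariance, shrunk_covariance

rng = np.random.RandomState(0)
X = rng.randn(500, 6)

emp = empirical_covariance(X, assume_centered=True)
# shrinkage=0 leaves the empirical covariance untouched;
# shrinkage=1 would replace it with a scaled identity matrix
np.testing.assert_allclose(shrunk_covariance(emp, shrinkage=0.), emp)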
15 changes: 4 additions & 11 deletions mne/utils/_testing.py
@@ -122,14 +122,6 @@ def requires_module(function, name, call=None):
raise ImportError
"""

_sklearn_call = """
required_version = '0.14'
import sklearn
version = LooseVersion(sklearn.__version__)
if version < required_version:
raise ImportError
"""

_mayavi_call = """
with warnings.catch_warnings(record=True): # traits
from mayavi import mlab
@@ -152,7 +144,7 @@

requires_pandas = partial(requires_module, name='pandas', call=_pandas_call)
requires_pylsl = partial(requires_module, name='pylsl')
requires_sklearn = partial(requires_module, name='sklearn', call=_sklearn_call)
requires_sklearn = partial(requires_module, name='sklearn')
requires_mayavi = partial(requires_module, name='mayavi', call=_mayavi_call)
requires_mne = partial(requires_module, name='MNE-C', call=_mne_call)

@@ -452,8 +444,9 @@ def assert_meg_snr(actual, desired, min_tol, med_tol=500., chpi_med_tol=500.,

def assert_snr(actual, desired, tol):
"""Assert actual and desired arrays are within some SNR tolerance."""
snr = (linalg.norm(desired, ord='fro') /
linalg.norm(desired - actual, ord='fro'))
with np.errstate(divide='ignore'): # allow infinite
snr = (linalg.norm(desired, ord='fro') /
linalg.norm(desired - actual, ord='fro'))
assert snr >= tol, '%f < %f' % (snr, tol)


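The np.errstate(divide='ignore') guard added to assert_snr above exists because identical arrays give a zero error norm and hence an infinite SNR, which should satisfy any finite tolerance rather than emit a divide-by-zero warning. A minimal sketch of that edge case (NumPy/SciPy only):

import numpy as np
from scipy import linalg

desired = np.eye(3)
actual = desired.copy()  # identical arrays -> zero error norm

with np.errstate(divide='ignore'):
    snr = linalg.norm(desired, ord='fro') / linalg.norm(desired - actual, ord='fro')

print(snr)  # inf, so any finite tolerance passes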
