Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update comparisons #107

Merged
merged 4 commits into from
May 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.DS_Store
.*.un~
.fuse_hidden*

# Created by https://www.gitignore.io/api/python

Expand Down
16 changes: 11 additions & 5 deletions climpred/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _m2m(ds, supervector_dim='svd'):

def _m2e(ds, supervector_dim='svd'):
"""
Create two supervectors to compare all members to ensemble mean.
Create two supervectors to compare all members to ensemble mean while leaving out the reference when creating the forecasts.

Args:
ds (xarray object): xr.Dataset/xr.DataArray with member and ensemble
Expand All @@ -145,10 +145,16 @@ def _m2e(ds, supervector_dim='svd'):
reference (xarray object): reference.

"""
reference = ds.mean('member')
forecast, reference = xr.broadcast(ds, reference)
forecast = _stack_to_supervector(forecast, new_dim=supervector_dim)
reference = _stack_to_supervector(reference, new_dim=supervector_dim)
reference_list = []
forecast_list = []
for m in ds.member.values:
forecast = _drop_members(ds, rmd_member=[m]).mean('member')
reference = ds.sel(member=m).squeeze()
forecast, reference = xr.broadcast(forecast, reference)
forecast_list.append(forecast)
reference_list.append(reference)
reference = xr.concat(reference_list,'init').rename({'init': supervector_dim})
forecast = xr.concat(forecast_list,'init').rename({'init': supervector_dim})
return forecast, reference


Expand Down
128 changes: 109 additions & 19 deletions climpred/tests/test_comparisons.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import numpy as np
import pytest
import xarray as xr
from climpred.comparisons import _drop_members, _m2m
from xarray.testing import assert_equal

from climpred.comparisons import (_drop_members, _e2c, _m2c, _m2e, _m2m,
_stack_to_supervector)
from climpred.loadutils import open_dataset


Expand All @@ -19,20 +22,108 @@ def PM_da_control1d():
return da


@pytest.fixture
def PM_ds_ds1d():
ds = open_dataset('MPI-PM-DP-1D')
return ds
def m2e(ds, supervector_dim='svd'):
    """Build leave-one-out forecast/reference pairs, one per member.

    For every value of ``ds.member`` the reference is that single member
    (selected and squeezed) and the forecast is the ``member`` mean of
    whatever ``_drop_members`` returns when asked to remove that member.
    Each pair is broadcast against each other, the pairs are concatenated
    along ``init`` and that dimension is then renamed to
    ``supervector_dim``.

    Args:
        ds (xarray object): input carrying ``member`` and ``init``.
        supervector_dim (str): name for the concatenated dimension.

    Returns:
        forecast, reference (xarray objects): the concatenated results.
    """
    pairs = []
    for member in ds.member.values:
        others_mean = _drop_members(ds, rmd_member=[member]).mean('member')
        single = ds.sel(member=member).squeeze()
        # xr.broadcast returns the (forecast, reference) pair as a tuple.
        pairs.append(xr.broadcast(others_mean, single))

    renaming = {'init': supervector_dim}
    reference = xr.concat([ref for _, ref in pairs], 'init').rename(renaming)
    forecast = xr.concat([fct for fct, _ in pairs], 'init').rename(renaming)
    return forecast, reference


@pytest.fixture
def PM_ds_control1d():
ds = open_dataset('MPI-control-1D')
return ds
def test_e2c(PM_da_ds1d):
    """Compare `_e2c` with a hand-built ensemble-mean-to-control result.

    The expected reference is the member at position 0 (squeezed, with its
    ``member`` coordinate deleted when present). The expected forecast is
    the ``member`` mean of the output of ``_drop_members`` for that same
    member. Both have ``init`` renamed to ``svd`` before comparison.
    """
    ens = PM_da_ds1d
    actual_fct, actual_ref = _e2c(ens)

    svd = 'svd'
    control = [0]

    # Reference: the single control member without its member coordinate.
    expected_ref = ens.isel(member=control).squeeze()
    if 'member' in expected_ref.coords:
        del expected_ref['member']
    expected_ref = expected_ref.rename({'init': svd})

    # Forecast: mean over what is left after removing the control member.
    remaining = _drop_members(ens, rmd_member=[ens.member.values[control]])
    expected_fct = remaining.mean('member').rename({'init': svd})

    # Cheap shape sanity checks before the strict comparison.
    assert expected_fct.size == actual_fct.size
    assert expected_ref.size == actual_ref.size

    assert_equal(expected_fct, actual_fct)
    assert_equal(expected_ref, actual_ref)


def test_m2c(PM_da_ds1d):
    """Compare `_m2c` with a hand-built many-to-control result.

    The member at position 0 is taken as the reference. The forecast is
    the output of ``_drop_members`` for that member. Both are broadcast
    against each other and passed through ``_stack_to_supervector`` with
    ``new_dim='svd'`` before being compared with the `_m2c` output.
    """
    ens = PM_da_ds1d
    actual_fct, actual_ref = _m2c(ens)

    svd = 'svd'
    control = [0]

    control_run = ens.isel(member=control).squeeze()
    # Forecast members: the ensemble without the control member.
    others = _drop_members(ens, rmd_member=ens.member.values[control])
    fct_b, ref_b = xr.broadcast(others, control_run)
    expected_fct = _stack_to_supervector(fct_b, new_dim=svd)
    expected_ref = _stack_to_supervector(ref_b, new_dim=svd)

    # Cheap shape sanity checks before the strict comparison.
    assert expected_fct.size == actual_fct.size
    assert expected_ref.size == actual_ref.size

    assert_equal(expected_fct, actual_fct)
    assert_equal(expected_ref, actual_ref)


def test_m2e(PM_da_ds1d):
    """Compare `_m2e` with a hand-built many-to-ensemble-mean result.

    For each member in turn, that member is the reference and the
    ``member`` mean of the output of ``_drop_members`` for it is the
    forecast. The per-member pairs are broadcast, concatenated along
    ``init`` and renamed to ``svd`` before comparison.
    """
    ens = PM_da_ds1d
    actual_fct, actual_ref = _m2e(ens)

    svd = 'svd'
    pairs = []
    for member in ens.member.values:
        others_mean = _drop_members(ens, rmd_member=[member]).mean('member')
        single = ens.sel(member=member).squeeze()
        # xr.broadcast returns the (forecast, reference) pair as a tuple.
        pairs.append(xr.broadcast(others_mean, single))

    renaming = {'init': svd}
    expected_ref = xr.concat(
        [ref for _, ref in pairs], 'init').rename(renaming)
    expected_fct = xr.concat(
        [fct for fct, _ in pairs], 'init').rename(renaming)

    # Cheap shape sanity checks before the strict comparison.
    assert expected_fct.size == actual_fct.size
    assert expected_ref.size == actual_ref.size

    assert_equal(expected_fct, actual_fct)
    assert_equal(expected_ref, actual_ref)


def test_m2m(PM_da_ds1d):
"Test m2m basic functionality of many to many comparison"
"""Test many-to-many (m2m) comparison basic functionality.

Clean comparison: Remove one member from ensemble to use as reference. Take the remaining members as forecasts."""
ds = PM_da_ds1d
aforecast, areference = _m2m(ds)

Expand All @@ -46,15 +137,14 @@ def test_m2m(PM_da_ds1d):
for m2 in ds_reduced.member:
for i in ds.init:
reference_list.append(reference.sel(init=i))
forecast_list.append(
ds_reduced.sel(member=m2, init=i))
reference = xr.concat(
reference_list, supervector_dim)
reference[supervector_dim] = np.arange(1, 1+reference.svd.size)
forecast = xr.concat(
forecast_list, supervector_dim)
forecast[supervector_dim] = np.arange(1, 1+forecast.svd.size)
forecast_list.append(ds_reduced.sel(member=m2, init=i))
reference = xr.concat(reference_list, supervector_dim)
reference[supervector_dim] = np.arange(1, 1 + reference.svd.size)
forecast = xr.concat(forecast_list, supervector_dim)
forecast[supervector_dim] = np.arange(1, 1 + forecast.svd.size)
eforecast, ereference = forecast, reference

# very weak testing here
assert eforecast.size == aforecast.size
assert ereference.size == areference.size
#assert_equal(eforecast,aforecast)
#assert_equal(ereference,areference)
8 changes: 4 additions & 4 deletions climpred/tests/test_perfect_model_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
def PM_da_ds3d():
da = open_dataset('MPI-PM-DP-3D')
# Box in South Atlantic with no NaNs.
da = da.isel(x=slice(0, 50), y=slice(125, 150))
da = da.isel(x=slice(0, 5), y=slice(145, 150))
return da['tos']


@pytest.fixture
def PM_da_control3d():
da = open_dataset('MPI-control-3D')
da = da.isel(x=slice(0, 50), y=slice(125, 150))
da = da.isel(x=slice(0, 5), y=slice(145, 150))
# fix to span 300yr control
t = list(np.arange(da.time.size))
da = da.isel(time=t*6)
Expand All @@ -34,14 +34,14 @@ def PM_da_control3d():
@pytest.fixture
def PM_ds_ds3d():
ds = open_dataset('MPI-PM-DP-3D')
ds = ds.isel(x=slice(0, 50), y=slice(125, 150))
ds = ds.isel(x=slice(0, 5), y=slice(145, 150))
return ds


@pytest.fixture
def PM_ds_control3d():
ds = open_dataset('MPI-control-3D')
ds = ds.isel(x=slice(0, 50), y=slice(125, 150))
ds = ds.isel(x=slice(0, 5), y=slice(145, 150))
t = list(np.arange(ds.time.size))
ds = ds.isel(time=t*6)
ds['time'] = np.arange(3000, 3000 + ds.time.size)
Expand Down