Skip to content

Commit

Permalink
Update comparisons (#107)
Browse files Browse the repository at this point in the history
* sshfs annoying tmp files

* create unitests for all comparisons. change _m2e to remove ref from ens

*  reduced 3d area to increase speed



Former-commit-id: 6af09b0
  • Loading branch information
aaronspring authored May 3, 2019
1 parent 8695df7 commit dabbefc
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.DS_Store
.*.un~
.fuse_hidden*

# Created by https://www.gitignore.io/api/python

Expand Down
16 changes: 11 additions & 5 deletions climpred/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _m2m(ds, supervector_dim='svd'):

def _m2e(ds, supervector_dim='svd'):
"""
Create two supervectors to compare all members to ensemble mean.
Create two supervectors to compare all members to ensemble mean while leaving out the reference when creating the forecasts.
Args:
ds (xarray object): xr.Dataset/xr.DataArray with member and ensemble
Expand All @@ -145,10 +145,16 @@ def _m2e(ds, supervector_dim='svd'):
reference (xarray object): reference.
"""
reference = ds.mean('member')
forecast, reference = xr.broadcast(ds, reference)
forecast = _stack_to_supervector(forecast, new_dim=supervector_dim)
reference = _stack_to_supervector(reference, new_dim=supervector_dim)
reference_list = []
forecast_list = []
for m in ds.member.values:
forecast = _drop_members(ds, rmd_member=[m]).mean('member')
reference = ds.sel(member=m).squeeze()
forecast, reference = xr.broadcast(forecast, reference)
forecast_list.append(forecast)
reference_list.append(reference)
reference = xr.concat(reference_list,'init').rename({'init': supervector_dim})
forecast = xr.concat(forecast_list,'init').rename({'init': supervector_dim})
return forecast, reference


Expand Down
128 changes: 109 additions & 19 deletions climpred/tests/test_comparisons.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import numpy as np
import pytest
import xarray as xr
from climpred.comparisons import _drop_members, _m2m
from xarray.testing import assert_equal

from climpred.comparisons import (_drop_members, _e2c, _m2c, _m2e, _m2m,
_stack_to_supervector)
from climpred.loadutils import open_dataset


Expand All @@ -19,20 +22,108 @@ def PM_da_control1d():
return da


@pytest.fixture
def PM_ds_ds1d():
ds = open_dataset('MPI-PM-DP-1D')
return ds
def m2e(ds, supervector_dim='svd'):
reference_list = []
forecast_list = []
for m in ds.member.values:
forecast = _drop_members(ds, rmd_member=[m]).mean('member')
reference = ds.sel(member=m).squeeze()
forecast, reference = xr.broadcast(forecast, reference)
forecast_list.append(forecast)
reference_list.append(reference)
reference = xr.concat(reference_list,
'init').rename({'init': supervector_dim})
forecast = xr.concat(forecast_list,
'init').rename({'init': supervector_dim})
return forecast, reference


@pytest.fixture
def PM_ds_control1d():
ds = open_dataset('MPI-control-1D')
return ds
def test_e2c(PM_da_ds1d):
"""Test ensemble_mean-to-control (which can be any other one member) (e2c) comparison basic functionality.
Clean comparison: Remove one control member from ensemble to use as reference. Take the remaining member mean as forecasts."""
ds = PM_da_ds1d
aforecast, areference = _e2c(ds)

control_member = [0]
supervector_dim = 'svd'
reference = ds.isel(member=control_member).squeeze()
if 'member' in reference.coords:
del reference['member']
reference = reference.rename({'init': supervector_dim})
# drop the member being reference
ds = _drop_members(ds, rmd_member=[ds.member.values[control_member]])
forecast = ds.mean('member')
forecast = forecast.rename({'init': supervector_dim})

eforecast, ereference = forecast, reference
# very weak testing on shape
assert eforecast.size == aforecast.size
assert ereference.size == areference.size

assert_equal(eforecast, aforecast)
assert_equal(ereference, areference)


def test_m2c(PM_da_ds1d):
"""Test many-to-control (which can be any other one member) (m2c) comparison basic functionality.
Clean comparison: Remove one control member from ensemble to use as reference. Take the remaining members as forecasts."""
ds = PM_da_ds1d
aforecast, areference = _m2c(ds)

supervector_dim = 'svd'
control_member = [0]
reference = ds.isel(member=control_member).squeeze()
# drop the member being reference
ds_dropped = _drop_members(ds, rmd_member=ds.member.values[control_member])
forecast, reference = xr.broadcast(ds_dropped, reference)
forecast = _stack_to_supervector(forecast, new_dim=supervector_dim)
reference = _stack_to_supervector(reference, new_dim=supervector_dim)

eforecast, ereference = forecast, reference
# very weak testing on shape
assert eforecast.size == aforecast.size
assert ereference.size == areference.size

assert_equal(eforecast, aforecast)
assert_equal(ereference, areference)


def test_m2e(PM_da_ds1d):
"""Test many-to-ensemble-mean (m2e) comparison basic functionality.
Clean comparison: Remove one member from ensemble to use as reference. Take the remaining members as forecasts."""
ds = PM_da_ds1d
aforecast, areference = _m2e(ds)

supervector_dim = 'svd'
reference_list = []
forecast_list = []
for m in ds.member.values:
forecast = _drop_members(ds, rmd_member=[m]).mean('member')
reference = ds.sel(member=m).squeeze()
forecast, reference = xr.broadcast(forecast, reference)
forecast_list.append(forecast)
reference_list.append(reference)
reference = xr.concat(reference_list,
'init').rename({'init': supervector_dim})
forecast = xr.concat(forecast_list,
'init').rename({'init': supervector_dim})

eforecast, ereference = forecast, reference
# very weak testing on shape
assert eforecast.size == aforecast.size
assert ereference.size == areference.size

assert_equal(eforecast, aforecast)
assert_equal(ereference, areference)


def test_m2m(PM_da_ds1d):
"Test m2m basic functionality of many to many comparison"
"""Test many-to-many (m2m) comparison basic functionality.
Clean comparison: Remove one member from ensemble to use as reference. Take the remaining members as forecasts."""
ds = PM_da_ds1d
aforecast, areference = _m2m(ds)

Expand All @@ -46,15 +137,14 @@ def test_m2m(PM_da_ds1d):
for m2 in ds_reduced.member:
for i in ds.init:
reference_list.append(reference.sel(init=i))
forecast_list.append(
ds_reduced.sel(member=m2, init=i))
reference = xr.concat(
reference_list, supervector_dim)
reference[supervector_dim] = np.arange(1, 1+reference.svd.size)
forecast = xr.concat(
forecast_list, supervector_dim)
forecast[supervector_dim] = np.arange(1, 1+forecast.svd.size)
forecast_list.append(ds_reduced.sel(member=m2, init=i))
reference = xr.concat(reference_list, supervector_dim)
reference[supervector_dim] = np.arange(1, 1 + reference.svd.size)
forecast = xr.concat(forecast_list, supervector_dim)
forecast[supervector_dim] = np.arange(1, 1 + forecast.svd.size)
eforecast, ereference = forecast, reference

# very weak testing here
assert eforecast.size == aforecast.size
assert ereference.size == areference.size
#assert_equal(eforecast,aforecast)
#assert_equal(ereference,areference)
8 changes: 4 additions & 4 deletions climpred/tests/test_perfect_model_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
def PM_da_ds3d():
da = open_dataset('MPI-PM-DP-3D')
# Box in South Atlantic with no NaNs.
da = da.isel(x=slice(0, 50), y=slice(125, 150))
da = da.isel(x=slice(0, 5), y=slice(145, 150))
return da['tos']


@pytest.fixture
def PM_da_control3d():
da = open_dataset('MPI-control-3D')
da = da.isel(x=slice(0, 50), y=slice(125, 150))
da = da.isel(x=slice(0, 5), y=slice(145, 150))
# fix to span 300yr control
t = list(np.arange(da.time.size))
da = da.isel(time=t*6)
Expand All @@ -34,14 +34,14 @@ def PM_da_control3d():
@pytest.fixture
def PM_ds_ds3d():
ds = open_dataset('MPI-PM-DP-3D')
ds = ds.isel(x=slice(0, 50), y=slice(125, 150))
ds = ds.isel(x=slice(0, 5), y=slice(145, 150))
return ds


@pytest.fixture
def PM_ds_control3d():
ds = open_dataset('MPI-control-3D')
ds = ds.isel(x=slice(0, 50), y=slice(125, 150))
ds = ds.isel(x=slice(0, 5), y=slice(145, 150))
t = list(np.arange(ds.time.size))
ds = ds.isel(time=t*6)
ds['time'] = np.arange(3000, 3000 + ds.time.size)
Expand Down

0 comments on commit dabbefc

Please sign in to comment.