refactor bootstrap: first resample, then metric #355

Merged
merged 43 commits on May 28, 2020
Commits (43)
c65c04d
refactor bootstrap: first resample, then metric
Apr 19, 2020
151c942
Merge branch 'master' into pr-355/bradyrx/AS_bootstrap_metric_refactor
Apr 23, 2020
fd35e8e
Merge branch 'master' into AS_bootstrap_metric_refactor
aaronspring Apr 23, 2020
8526ca9
fix signif notebook
Apr 23, 2020
11c140f
fix asv warnings
Apr 23, 2020
3a7cf0a
fix too large iteration number for bootstrapping CI
Apr 23, 2020
d2f4038
rm print
Apr 23, 2020
bb3d71b
rm print
aaronspring Apr 23, 2020
8236c7b
Merge branch 'master' into AS_bootstrap_metric_refactor
aaronspring Apr 29, 2020
5f8d859
Merge branch 'master' into AS_bootstrap_metric_refactor
aaronspring May 5, 2020
756af52
Merge branch 'AS_bootstrap_metric_refactor' of github.com:bradyrx/cli…
May 5, 2020
f1d3d75
Merge branch 'AS_bootstrap_metric_refactor' of github.com:bradyrx/cli…
May 5, 2020
977f91b
refactor bootstrap_hindcast_over_init_dim
May 5, 2020
e9cf5d0
asv self.iterations
May 5, 2020
bb58e6f
dont create more uninit members PM
May 5, 2020
376981a
bootstrap_hind init warning
May 5, 2020
e466123
uninit PM loop
May 5, 2020
a1cf7d8
remove warnings transpose coords
May 5, 2020
da1e5a2
bootstrap PM require 50 members to bootstrap from
May 5, 2020
a910f6b
bugfix
May 5, 2020
0a1978e
fix issue by compute (#362)
aaronspring May 9, 2020
a0d95b7
Pr 355/bradyrx/as bootstrap metric refactor (#363)
aaronspring May 11, 2020
6da48a7
Pr 355/bradyrx/as bootstrap metric refactor (#365)
aaronspring May 11, 2020
852a2a0
Pr 355/bradyrx/as bootstrap metric refactor (#366)
aaronspring May 11, 2020
48316f4
Pr 355/bradyrx/as bootstrap metric refactor (#367)
aaronspring May 12, 2020
4d97e2f
first uninit skill and copy(deep=True)
aaronspring May 16, 2020
15a20bb
lint
May 16, 2020
7e196b4
flake fstr and l -> lead
May 16, 2020
c9abd5e
lint
May 16, 2020
5a81542
transform_kwargs only for xr.DataArray
May 23, 2020
6eadd93
resample_iterations_idx for small eager data, resample_iterations for…
May 23, 2020
1941d6d
cleanup a bit
May 23, 2020
e417d2d
cleanup
May 23, 2020
92ded3a
fix get_path and cleanup bootstrap
May 23, 2020
ab8444d
refactor bootstrap.py: bootstrap stats to end of file
May 23, 2020
80870b1
add bootstrap tests and rm get_chunksize
May 23, 2020
ef189ad
fix performance warning ds
May 23, 2020
8e9170a
asv distributed
May 23, 2020
03838f0
fix hindcast iterations label
May 23, 2020
ef2d332
bootstrap_hindcast uninit fix
May 23, 2020
8bbba54
increase coverage
May 24, 2020
5e3fc1e
bootstrap uninit resample less memory consumption, dim_max
May 25, 2020
2d4fa2b
increase testing and fix persistence coords copy iterations bug
May 25, 2020
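Commit 6eadd93 above points at two resampling code paths: an index-based variant intended for small eager (in-memory) data, and a second variant (the message is truncated) that presumably loops per iteration for larger or chunked data. The following is a rough sketch of that trade-off under those assumptions; the helper names are hypothetical and are not climpred's actual functions:

```python
# Hypothetical illustration of two bootstrap-resampling strategies
# (names are illustrative, not climpred's API).
import numpy as np
import xarray as xr


def resample_by_index(da, dim="member", iterations=100):
    """Vectorized draw: one fancy-indexed selection over a 2-D index array.

    Fast for small in-memory data, but materializes iterations x member values at once.
    """
    rng = np.random.default_rng(42)
    idx = xr.DataArray(
        rng.integers(0, da.sizes[dim], size=(iterations, da.sizes[dim])),
        dims=("iteration", dim),
    )
    return da.isel({dim: idx})


def resample_by_loop(da, dim="member", iterations=100):
    """Looped draw: one resample at a time, concatenated along 'iteration'.

    Slower due to Python-level overhead, but peak memory stays near a single iteration.
    """
    rng = np.random.default_rng(42)
    size = da.sizes[dim]
    draws = [da.isel({dim: rng.integers(0, size, size=size)}) for _ in range(iterations)]
    return xr.concat(draws, dim="iteration")


da = xr.DataArray(np.random.default_rng(0).standard_normal((8, 4)), dims=("member", "lead"))
print(dict(resample_by_index(da, iterations=50).sizes))  # iteration=50, member=8, lead=4 (order may vary)
print(dict(resample_by_loop(da, iterations=50).sizes))
```

The vectorized path trades memory for speed while the loop is gentler on memory and plays better with chunked (dask) inputs, which would be consistent with the "small eager data" hint in the commit message.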
3 changes: 3 additions & 0 deletions CHANGELOG.rst
@@ -77,6 +77,9 @@ Internals/Minor Fixes
climpred.stats (:pr:`354`) `Aaron Spring`_.
- Require ``cftime v1.1.2``, which modifies their object handling to create 200-400x
speedups in some basic operations. (:pr:`356`) `Riley X. Brady`_.
- Resample first and then calculate skill in
:py:func:`~climpred.bootstrap.bootstrap_perfect_model` and
:py:func:`~climpred.bootstrap.bootstrap_hindcast` (:pr:`355`) `Aaron Spring`_.


Documentation
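The CHANGELOG entry above captures the core idea of this PR: draw all bootstrap resamples first, then apply the skill metric once across the new iteration dimension, instead of recomputing the metric inside every draw. Below is a minimal, self-contained sketch of that pattern in plain numpy/xarray; `resample_members` and the mean-absolute-error "skill" are hypothetical stand-ins, not climpred's actual implementation:

```python
# Minimal sketch of "first resample, then metric" (hypothetical, illustrative only).
import numpy as np
import xarray as xr


def resample_members(fct, dim="member", iterations=500):
    """Draw `iterations` resamples with replacement along `dim`."""
    rng = np.random.default_rng(0)
    size = fct.sizes[dim]
    draws = [fct.isel({dim: rng.integers(0, size, size=size)}) for _ in range(iterations)]
    return xr.concat(draws, dim="iteration")


# toy initialized forecast (member x init) and verification series (init)
fct = xr.DataArray(np.random.default_rng(1).standard_normal((10, 20)), dims=("member", "init"))
obs = xr.DataArray(np.random.default_rng(2).standard_normal(20), dims=("init",))

boot = resample_members(fct)                           # 1) resample once, up front
skill = abs(boot.mean("member") - obs).mean("init")    # 2) then the metric, per iteration
ci = skill.quantile([0.025, 0.975], dim="iteration")   # bootstrap confidence interval
print(ci.values)
```

Because the `iteration` dimension exists before the metric is applied, the metric is written once and evaluated over all draws together, which is the structural change the PR title describes.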
67 changes: 51 additions & 16 deletions asv_bench/benchmarks/benchmarks_hindcast.py
@@ -1,7 +1,9 @@
import numpy as np
import xarray as xr
from dask.distributed import Client

from climpred.bootstrap import bootstrap_hindcast
from climpred.metrics import PROBABILISTIC_METRICS
from climpred.prediction import compute_hindcast

from . import ensure_loaded, parameterized, randn, requires_dask
@@ -11,7 +13,7 @@
# only take comparisons compatible with probabilistic metrics
HINDCAST_COMPARISONS = ['m2o']

ITERATIONS = 8
ITERATIONS = 16


class Generate:
@@ -30,16 +32,20 @@ def make_hind_obs(self):
self.uninit = xr.Dataset()

self.nmember = 3
self.nlead = 3
self.nx = 64
self.ny = 64
self.nlead = 5
self.nx = 72
self.ny = 36
self.iterations = ITERATIONS
self.init_start = 1960
self.init_end = 2000
self.ninit = self.init_end - self.init_start
self.client = None

FRAC_NAN = 0.0

inits = np.arange(self.init_start, self.init_end)
inits = xr.cftime_range(
start=str(self.init_start), end=str(self.init_end - 1), freq='YS'
)
leads = np.arange(1, 1 + self.nlead)
members = np.arange(1, 1 + self.nmember)

@@ -86,6 +92,9 @@ def make_hind_obs(self):
)

self.hind.attrs = {'history': 'created for xarray benchmarking'}
self.hind.lead.attrs['units'] = 'years'
self.uninit.time.attrs['units'] = 'years'
self.observations.time.attrs['units'] = 'years'


class Compute(Generate):
@@ -99,48 +108,60 @@ def setup(self, *args, **kwargs):
@parameterized(['metric', 'comparison'], (METRICS, HINDCAST_COMPARISONS))
def time_compute_hindcast(self, metric, comparison):
"""Take time for `compute_hindcast`."""
dim = 'member' if metric in PROBABILISTIC_METRICS else 'init'
ensure_loaded(
compute_hindcast(
self.hind, self.observations, metric=metric, comparison=comparison,
self.hind,
self.observations,
metric=metric,
comparison=comparison,
dim=dim,
)
)

@parameterized(['metric', 'comparison'], (METRICS, HINDCAST_COMPARISONS))
def peakmem_compute_hindcast(self, metric, comparison):
"""Take memory peak for `compute_hindcast`."""
dim = 'member' if metric in PROBABILISTIC_METRICS else 'init'
ensure_loaded(
compute_hindcast(
self.hind, self.observations, metric=metric, comparison=comparison,
self.hind,
self.observations,
metric=metric,
comparison=comparison,
dim=dim,
)
)

@parameterized(['metric', 'comparison'], (METRICS, HINDCAST_COMPARISONS))
def time_bootstrap_hindcast(self, metric, comparison):
"""Take time for `bootstrap_hindcast`."""
dim = 'member' if metric in PROBABILISTIC_METRICS else 'init'
ensure_loaded(
bootstrap_hindcast(
self.hind,
self.uninit,
self.observations,
metric=metric,
comparison=comparison,
iterations=ITERATIONS,
dim='member',
iterations=self.iterations,
dim=dim,
)
)

@parameterized(['metric', 'comparison'], (METRICS, HINDCAST_COMPARISONS))
def peakmem_bootstrap_hindcast(self, metric, comparison):
"""Take memory peak for `bootstrap_hindcast`."""
dim = 'member' if metric in PROBABILISTIC_METRICS else 'init'
ensure_loaded(
bootstrap_hindcast(
self.hind,
self.uninit,
self.observations,
metric=metric,
comparison=comparison,
iterations=ITERATIONS,
dim='member',
iterations=self.iterations,
dim=dim,
)
)

@@ -155,11 +176,24 @@ def setup(self, *args, **kwargs):
# https://github.com/pydata/xarray/blob/stable/asv_bench/benchmarks/rolling.py
super().setup(**kwargs)
# chunk along a spatial dimension to enable embarrassingly parallel computation
self.hind = self.hind['var'].chunk({'lon': self.nx // ITERATIONS})
self.observations = self.observations['var'].chunk(
{'lon': self.nx // ITERATIONS}
)
self.uninit = self.uninit['var'].chunk({'lon': self.nx // ITERATIONS})
self.hind = self.hind['var'].chunk()
self.observations = self.observations['var'].chunk()
self.uninit = self.uninit['var'].chunk()


class ComputeDaskDistributed(ComputeDask):
def setup(self, *args, **kwargs):
"""Benchmark time and peak memory of `compute_hindcast` and
`bootstrap_hindcast`. This executes the same tests as `Compute` but
on chunked data with dask.distributed.Client."""
requires_dask()
# magic taken from
# https://github.com/pydata/xarray/blob/stable/asv_bench/benchmarks/rolling.py
super().setup(**kwargs)
self.client = Client()

def cleanup(self):
self.client.shutdown()


class ComputeSmall(Compute):
@@ -176,3 +210,4 @@ def setup(self, *args, **kwargs):
self.hind = self.hind.mean(spatial_dims)
self.observations = self.observations.mean(spatial_dims)
self.uninit = self.uninit.mean(spatial_dims)
self.iterations = 500
70 changes: 52 additions & 18 deletions asv_bench/benchmarks/benchmarks_perfect_model.py
@@ -1,7 +1,9 @@
import numpy as np
import xarray as xr
from dask.distributed import Client

from climpred.bootstrap import bootstrap_perfect_model
from climpred.metrics import PROBABILISTIC_METRICS
from climpred.prediction import compute_perfect_model

from . import ensure_loaded, parameterized, randn, requires_dask
@@ -11,7 +13,7 @@
# only take comparisons compatible with probabilistic metrics
PM_COMPARISONS = ['m2m', 'm2c']

ITERATIONS = 8
ITERATIONS = 16


class Generate:
@@ -26,23 +28,32 @@ def make_initialized_control(self):
perfect-model experiment."""
self.ds = xr.Dataset()
self.control = xr.Dataset()
self.nmember = 3
self.ninit = 4
self.nlead = 3
self.nx = 64
self.ny = 64
self.nmember = 5
self.ninit = 6
self.nlead = 10
self.iterations = ITERATIONS
self.nx = 72
self.ny = 36
self.control_start = 3000
self.control_end = 3300
self.ntime = 300
self.ntime = self.control_end - self.control_start
self.client = None

FRAC_NAN = 0.0

times = np.arange(self.control_start, self.control_end)
times = xr.cftime_range(
start=str(self.control_start),
periods=self.ntime,
freq='YS',
calendar='noleap',
)
leads = np.arange(1, 1 + self.nlead)
members = np.arange(1, 1 + self.nmember)
inits = (
np.random.choice(self.control_end - self.control_start, self.ninit)
+ self.control_start
inits = xr.cftime_range(
start=str(self.control_start),
periods=self.ninit,
freq='10YS',
calendar='noleap',
)

lons = xr.DataArray(
@@ -80,6 +91,8 @@ def make_initialized_control(self):
)

self.ds.attrs = {'history': 'created for xarray benchmarking'}
self.ds.lead.attrs['units'] = 'years'
self.control.time.attrs['units'] = 'years'


class Compute(Generate):
@@ -94,44 +107,50 @@ def setup(self, *args, **kwargs):
@parameterized(['metric', 'comparison'], (METRICS, PM_COMPARISONS))
def time_compute_perfect_model(self, metric, comparison):
"""Take time for `compute_perfect_model`."""
dim = 'member' if metric in PROBABILISTIC_METRICS else None
ensure_loaded(
compute_perfect_model(
self.ds, self.control, metric=metric, comparison=comparison
self.ds, self.control, metric=metric, comparison=comparison, dim=dim
)
)

@parameterized(['metric', 'comparison'], (METRICS, PM_COMPARISONS))
def peakmem_compute_perfect_model(self, metric, comparison):
"""Take memory peak for `compute_perfect_model`."""
dim = 'member' if metric in PROBABILISTIC_METRICS else None
ensure_loaded(
compute_perfect_model(
self.ds, self.control, metric=metric, comparison=comparison
self.ds, self.control, metric=metric, comparison=comparison, dim=dim
)
)

@parameterized(['metric', 'comparison'], (METRICS, PM_COMPARISONS))
def time_bootstrap_perfect_model(self, metric, comparison):
"""Take time for `bootstrap_perfect_model`."""
dim = 'member' if metric in PROBABILISTIC_METRICS else None
ensure_loaded(
bootstrap_perfect_model(
self.ds,
self.control,
metric=metric,
comparison=comparison,
iterations=ITERATIONS,
iterations=self.iterations,
dim=dim,
)
)

@parameterized(['metric', 'comparison'], (METRICS, PM_COMPARISONS))
def peakmem_bootstrap_perfect_model(self, metric, comparison):
"""Take memory peak for `bootstrap_perfect_model`."""
dim = 'member' if metric in PROBABILISTIC_METRICS else None
ensure_loaded(
bootstrap_perfect_model(
self.ds,
self.control,
metric=metric,
comparison=comparison,
iterations=ITERATIONS,
iterations=self.iterations,
dim=dim,
)
)

@@ -146,8 +165,23 @@ def setup(self, *args, **kwargs):
# https://github.com/pydata/xarray/blob/stable/asv_bench/benchmarks/rolling.py
super().setup(**kwargs)
# chunk along a spatial dimension to enable embarrassingly parallel computation
self.ds = self.ds['var'].chunk({'lon': self.nx // ITERATIONS})
self.control = self.control['var'].chunk({'lon': self.nx // ITERATIONS})
self.ds = self.ds['var'].chunk()
self.control = self.control['var'].chunk()


class ComputeDaskDistributed(ComputeDask):
def setup(self, *args, **kwargs):
"""Benchmark time and peak memory of `compute_perfect_model` and
`bootstrap_perfect_model`. This executes the same tests as `Compute` but
on chunked data with dask.distributed.Client."""
requires_dask()
# magic taken from
# https://github.com/pydata/xarray/blob/stable/asv_bench/benchmarks/rolling.py
super().setup(**kwargs)
self.client = Client()

def cleanup(self):
self.client.shutdown()


class ComputeSmall(Compute):
@@ -158,7 +192,7 @@ def setup(self, *args, **kwargs):
# magic taken from
# https://github.com/pydata/xarray/blob/stable/asv_bench/benchmarks/rolling.py
super().setup(**kwargs)
# chunk along a spatial dimension to enable embarrassingly parallel computation
spatial_dims = ['lon', 'lat']
self.ds = self.ds.mean(spatial_dims)
self.control = self.control.mean(spatial_dims)
self.iterations = 500
20 changes: 12 additions & 8 deletions climpred/alignment.py
@@ -89,11 +89,12 @@ def _maximize_alignment(init_lead_matrix, all_verifs, leads):
# Probably a way to do this more efficiently since we're doing essentially
# the same thing at each step.
verif_dates = {
l: lead_dependent_verif_dates.sel(lead=l).dropna('time').to_index()
for l in leads
lead: lead_dependent_verif_dates.sel(lead=lead).dropna('time').to_index()
for lead in leads
}
inits = {
l: lead_dependent_verif_dates.sel(lead=l).dropna('time')['time'] for l in leads
lead: lead_dependent_verif_dates.sel(lead=lead).dropna('time')['time']
for lead in leads
}
return inits, verif_dates

@@ -106,9 +107,10 @@ def _same_inits_alignment(init_lead_matrix, valid_inits, all_verifs, leads, n, f
"""
verifies_at_all_leads = init_lead_matrix.isin(all_verifs).all('lead')
inits = valid_inits.where(verifies_at_all_leads, drop=True)
inits = {l: inits for l in leads}
inits = {lead: inits for lead in leads}
verif_dates = {
l: shift_cftime_index(inits[l], 'time', n, freq) for (l, n) in zip(leads, n)
lead: shift_cftime_index(inits[lead], 'time', n, freq)
for (lead, n) in zip(leads, n)
}
return inits, verif_dates

@@ -132,10 +134,12 @@ def _same_verifs_alignment(init_lead_matrix, valid_inits, all_verifs, leads, n,
verif_dates = xr.concat(common_set_of_verifs, 'time').to_index()
inits_that_verify_with_verif_dates = init_lead_matrix.isin(verif_dates)
inits = {
l: valid_inits.where(inits_that_verify_with_verif_dates.sel(lead=l), drop=True)
for l in leads
lead: valid_inits.where(
inits_that_verify_with_verif_dates.sel(lead=lead), drop=True
)
for lead in leads
}
verif_dates = {l: verif_dates for l in leads}
verif_dates = {lead: verif_dates for lead in leads}
return inits, verif_dates

