EWMA alignement with pandas and speedup #53

Merged
merged 46 commits into from
Mar 25, 2022
f3d1e25
[ewma] better management of nans when time-series starts with nans. U…
eserie Mar 9, 2022
1c8f86d
add ignore_na and return_info options
eserie Mar 10, 2022
c8a7625
simplify implementation
eserie Mar 10, 2022
96437dd
fix
eserie Mar 10, 2022
bea91a8
clean code
eserie Mar 10, 2022
2076209
add min_periods option
eserie Mar 10, 2022
d63cf06
add com parameter and docstring
eserie Mar 11, 2022
f8d8e1d
rename haiku param alpha in com
eserie Mar 11, 2022
12ea2c5
use logcom as haiku parameter. Add test with training
eserie Mar 11, 2022
6eec139
refine test for training ewma
eserie Mar 11, 2022
4227cf5
format code
eserie Mar 12, 2022
f15b30d
replace isnan_x by ~is_observation
eserie Mar 13, 2022
a22a0c1
Align implementation with pandas
eserie Mar 13, 2022
dc8454e
set dtype int for nobs
eserie Mar 14, 2022
a07f23e
decrease linearly when adjust=linear
eserie Mar 14, 2022
e094ed3
add numba ewma with linear adjustement
eserie Mar 14, 2022
59c931f
add EWMA demo notebook
eserie Mar 14, 2022
78fb86b
add some tests
eserie Mar 14, 2022
59990a4
correct numba implementation with state management
eserie Mar 14, 2022
1e4c702
update notebook
eserie Mar 14, 2022
7a140ea
add dataframe online.ewma
eserie Mar 14, 2022
18e61c2
format
eserie Mar 14, 2022
66ad974
reactivate test
eserie Mar 14, 2022
16db94a
state as dataframe, check dtypes
eserie Mar 14, 2022
fcf9d41
format
eserie Mar 14, 2022
9a55586
add Series accessor + format
eserie Mar 16, 2022
11279a2
fix mypy
eserie Mar 16, 2022
c985fa0
move numba modules
eserie Mar 16, 2022
dff0b22
remove line
eserie Mar 16, 2022
9420ed6
remove comment
eserie Mar 16, 2022
4120aca
correct flake8
eserie Mar 16, 2022
12ca1f1
add license
eserie Mar 16, 2022
e868ab4
format
eserie Mar 16, 2022
17ed82d
correct PctChange() to align with pandas behaviour. Introduce fillna_…
eserie Mar 17, 2022
54313f8
remove notebook 09
eserie Mar 18, 2022
1ba8e04
add numba to requirements
eserie Mar 23, 2022
1896f63
refactor accessors
eserie Mar 24, 2022
d9bb2e6
always have nobs state
eserie Mar 24, 2022
04bcc14
refactor numba
eserie Mar 24, 2022
6e02a2b
correct notebooks 1 & 2
eserie Mar 24, 2022
180bc0e
correct mypy
eserie Mar 25, 2022
38e972e
format notebook
eserie Mar 25, 2022
ae3de20
correct call to EWMA
eserie Mar 25, 2022
9e581f5
format notebook
eserie Mar 25, 2022
ab07e7f
remove versions in docs requirements
eserie Mar 25, 2022
a395831
move docs requirements in setup.cfg
eserie Mar 25, 2022
94 changes: 72 additions & 22 deletions docs/notebooks/01_demo_EWMA.ipynb

Large diffs are not rendered by default.

31 changes: 23 additions & 8 deletions docs/notebooks/01_demo_EWMA.md
@@ -9,7 +9,7 @@ jupyter:
format_version: '1.3'
jupytext_version: 1.13.3
kernelspec:
display_name: Python 3
display_name: Python 3 (ipykernel)
language: python
name: python3
---
@@ -96,15 +96,30 @@ dataframe = dataset.air.to_series().unstack(["lon", "lat"])
### EWMA with pandas

```python
air_temp_ewma = dataframe.ewm(alpha=1.0 / 10.0).mean()
_ = air_temp_ewma.iloc[:, 0].plot()
air_temp_ewma = dataframe.ewm(com=10).mean()
_ = air_temp_ewma.mean(1).plot()
```

## wax numba ewma

```python
from wax.numba.ewma_numba import register_wax_numba
```

```python
register_wax_numba()
```

```python
air_temp_ewma = dataframe.wax_numba.ewm(com=10).mean()
_ = air_temp_ewma.mean(1).plot()
```

### EWMA with WAX-ML

```python
air_temp_ewma = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()
_ = air_temp_ewma.iloc[:, 0].plot()
air_temp_ewma = dataframe.wax.ewm(com=10).mean()
_ = air_temp_ewma.mean(1).plot()
```

On small data, WAX-ML's EWMA is slower than Pandas' because of the expensive data conversion steps.
@@ -122,8 +137,8 @@ from wax.modules import EWMA

def my_custom_function(dataset):
return {
"air_10": EWMA(1.0 / 10.0)(dataset["air"]),
"air_100": EWMA(1.0 / 100.0)(dataset["air"]),
"air_10": EWMA(com=10)(dataset["air"]),
"air_100": EWMA(com=100)(dataset["air"]),
}


@@ -132,5 +147,5 @@ output, state = dataset.wax.stream().apply(
my_custom_function, format_dims=dataset.air.dims
)

_ = output.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().plot(figsize=(12, 8))
_ = output.isel(lat=0, lon=0).drop(["lat", "lon"]).to_dataframe().plot(figsize=(12, 8))
```
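
Review note on the `alpha=1.0 / 10.0` to `com=10` change above: in pandas, the center of mass maps to the smoothing factor as `alpha = 1 / (1 + com)`, so `com=10` corresponds to `alpha = 1/11`, not `1/10`. The notebook outputs therefore shift slightly. The following is a minimal sketch using plain pandas only (the sample series is made up for illustration and is not the notebook's dataset):

```python
import numpy as np
import pandas as pd

# Illustrative data, not taken from the notebook.
series = pd.Series([1.0, 2.0, 4.0, 8.0, 16.0])

com = 10
alpha_from_com = 1.0 / (1.0 + com)  # pandas' documented relation: 1 / 11

by_com = series.ewm(com=com).mean()
by_alpha = series.ewm(alpha=alpha_from_com).mean()
by_old_alpha = series.ewm(alpha=1.0 / 10.0).mean()

# com=10 and alpha=1/11 describe the same filter ...
same_as_converted_alpha = np.allclose(by_com, by_alpha)
# ... while the previous alpha=1/10 is a slightly faster filter.
differs_from_old_alpha = not np.allclose(by_com, by_old_alpha)
```

If the old curves need to be reproduced exactly, `com=9` is the equivalent of `alpha=1/10` under the same relation.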
25 changes: 17 additions & 8 deletions docs/notebooks/01_demo_EWMA.py
@@ -9,7 +9,7 @@
# format_version: '1.5'
# jupytext_version: 1.13.3
# kernelspec:
# display_name: Python 3
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---
@@ -84,13 +84,22 @@

# ### EWMA with pandas

air_temp_ewma = dataframe.ewm(alpha=1.0 / 10.0).mean()
_ = air_temp_ewma.iloc[:, 0].plot()
air_temp_ewma = dataframe.ewm(com=10).mean()
_ = air_temp_ewma.mean(1).plot()

# ## wax numba ewma

from wax.numba.ewma_numba import register_wax_numba

register_wax_numba()

air_temp_ewma = dataframe.wax_numba.ewm(com=10).mean()
_ = air_temp_ewma.mean(1).plot()

# ### EWMA with WAX-ML

air_temp_ewma = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()
_ = air_temp_ewma.iloc[:, 0].plot()
air_temp_ewma = dataframe.wax.ewm(com=10).mean()
_ = air_temp_ewma.mean(1).plot()

# On small data, WAX-ML's EWMA is slower than Pandas' because of the expensive data conversion steps.
# WAX-ML's accessors are interesting to use on large data loads
@@ -106,8 +115,8 @@

def my_custom_function(dataset):
return {
"air_10": EWMA(1.0 / 10.0)(dataset["air"]),
"air_100": EWMA(1.0 / 100.0)(dataset["air"]),
"air_10": EWMA(com=10)(dataset["air"]),
"air_100": EWMA(com=100)(dataset["air"]),
}


@@ -116,4 +125,4 @@ def my_custom_function(dataset):
my_custom_function, format_dims=dataset.air.dims
)

_ = output.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().plot(figsize=(12, 8))
_ = output.isel(lat=0, lon=0).drop(["lat", "lon"]).to_dataframe().plot(figsize=(12, 8))
88 changes: 48 additions & 40 deletions docs/notebooks/02_Synchronize_data_streams.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions docs/notebooks/02_Synchronize_data_streams.md
@@ -9,7 +9,7 @@ jupyter:
format_version: '1.3'
jupytext_version: 1.13.3
kernelspec:
display_name: Python 3
display_name: Python 3 (ipykernel)
language: python
name: python3
---
@@ -90,9 +90,9 @@ from wax.modules import EWMA

def my_custom_function(dataset):
return {
"air_10": EWMA(1.0 / 10.0)(dataset["air"]),
"air_100": EWMA(1.0 / 100.0)(dataset["air"]),
"ground_100": EWMA(1.0 / 100.0)(dataset["ground"]),
"air_10": EWMA(com=10)(dataset["air"]),
"air_100": EWMA(com=100.0)(dataset["air"]),
"ground_100": EWMA(com=100.0)(dataset["ground"]),
}
```

@@ -103,5 +103,5 @@ results, state = dataset.wax.stream(
```

```python
_ = results.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().plot(figsize=(12, 8))
_ = results.isel(lat=0, lon=0).drop(["lat", "lon"]).to_dataframe().plot(figsize=(12, 8))
```
10 changes: 5 additions & 5 deletions docs/notebooks/02_Synchronize_data_streams.py
@@ -9,7 +9,7 @@
# format_version: '1.5'
# jupytext_version: 1.13.3
# kernelspec:
# display_name: Python 3
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---
@@ -82,9 +82,9 @@

def my_custom_function(dataset):
return {
"air_10": EWMA(1.0 / 10.0)(dataset["air"]),
"air_100": EWMA(1.0 / 100.0)(dataset["air"]),
"ground_100": EWMA(1.0 / 100.0)(dataset["ground"]),
"air_10": EWMA(com=10)(dataset["air"]),
"air_100": EWMA(com=100.0)(dataset["air"]),
"ground_100": EWMA(com=100.0)(dataset["ground"]),
}


@@ -94,4 +94,4 @@ def my_custom_function(dataset):
local_time="time", ffills={"day": 1}, pbar=True
).apply(my_custom_function, format_dims=dataset.air.dims)

_ = results.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().plot(figsize=(12, 8))
_ = results.isel(lat=0, lon=0).drop(["lat", "lon"]).to_dataframe().plot(figsize=(12, 8))
12 changes: 1 addition & 11 deletions docs/requirements.txt
@@ -1,11 +1 @@
# sphinx <4 required by myst-nb v0.12.0 (Feb 2021)
# sphinx >=3 required by sphinx-autodoc-typehints v1.11.1 (Oct 2020)
sphinx >=3, <4
sphinx_rtd_theme
sphinx-autodoc-typehints==1.11.1
jupyter-sphinx>=0.3.2
myst-nb
# Packages used for notebook execution
matplotlib
sklearn
.[dev,complete]
.[dev,complete,docs]
10 changes: 9 additions & 1 deletion setup.cfg
@@ -42,6 +42,7 @@ install_requires =
jaxlib>=0.1.67
jax<=0.2.21
dm-haiku >= 0.0.4
numba

[options.extras_require]
optional =
@@ -81,13 +82,17 @@ docs =
sphinx
sphinxcontrib-napoleon
sphinx_rtd_theme
sphinx-autodoc-typehints
sphinx-autosummary-accessors
ipython
ipykernel
jupyter-client
jupyter-sphinx
myst-nb
nbsphinx
scanpydoc

matplotlib
sklearn
[options.package_data]
wax =
py.typed
@@ -207,6 +212,9 @@ ignore_errors = True
[mypy-numpy.*]
ignore_missing_imports = True

[mypy-numba.*]
ignore_missing_imports = True

[mypy-opt_einsum.*]
ignore_missing_imports = True

21 changes: 11 additions & 10 deletions wax/accessors.py
@@ -13,7 +13,7 @@
# limitations under the License.
"""Define accessors for xarray and pandas data containers."""
from dataclasses import dataclass, field
from typing import Any, Callable, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union

import jax.numpy as jnp
import numpy as onp
@@ -315,22 +315,23 @@ def ewm(self, *args, **kwargs):
@dataclass(frozen=True)
class ExponentialMovingWindow:
accessor: WaxAccessor
alpha: float
com: Optional[float] = None
alpha: Optional[float] = None
min_periods: int = 0
adjust: bool = True
ignore_na: bool = False
initial_value: float = jnp.nan
return_state: bool = False
format_outputs: bool = True

def mean(self):
from wax.modules import EWMA

def _apply_ema(
accessor, alpha, adjust, params=None, state=None, *args, **kwargs
):
return accessor.stream(*args, **kwargs).apply(
lambda x: EWMA(alpha, adjust)(x),
params=params,
state=state,
rng=None,
def _apply_ema(*, accessor, return_state, format_outputs, **kwargs):
return accessor.stream(
return_state=return_state, format_outputs=format_outputs
).apply(
lambda x: EWMA(**kwargs)(x),
)

return _apply_ema(**self.__dict__)
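
Review note: the new `_apply_ema` relies on keyword-only parameters to peel the streaming options off the dataclass fields and forward everything else to `EWMA(**kwargs)`. Below is a minimal sketch of that pattern in isolation. `WindowSketch` and `fake_ewma` are hypothetical stand-ins, not WAX-ML classes, and the accessor is replaced by a plain string.

```python
from dataclasses import dataclass
from typing import Optional


def fake_ewma(**kwargs):
    """Hypothetical stand-in for a module constructor: just echoes its kwargs."""
    return kwargs


@dataclass(frozen=True)
class WindowSketch:
    accessor: str  # stand-in for the real accessor object
    com: Optional[float] = None
    alpha: Optional[float] = None
    min_periods: int = 0
    return_state: bool = False
    format_outputs: bool = True

    def mean(self):
        # Keyword-only names capture the stream options; every remaining
        # dataclass field falls into **kwargs and goes to the module.
        def _apply(*, accessor, return_state, format_outputs, **kwargs):
            stream_options = {
                "return_state": return_state,
                "format_outputs": format_outputs,
            }
            return accessor, stream_options, fake_ewma(**kwargs)

        # A frozen dataclass without slots still exposes its fields in __dict__.
        return _apply(**self.__dict__)


accessor, stream_options, module_kwargs = WindowSketch(accessor="df", com=10).mean()
```

One consequence of this design: any field later added to the dataclass is forwarded to the module automatically unless it is also named in the inner function's signature.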
2 changes: 1 addition & 1 deletion wax/accessors_test.py
@@ -41,7 +41,7 @@ def module_map(x):


def check_ema_state(state, ref_count=124):
assert (state["ewma"]["count"] == ref_count).all()
assert (state["ewma"]["nobs"] == ref_count).all()


def prepare_format_data(format):
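
Review note on the `count` to `nobs` rename: the name suggests the state now tracks the number of observations (non-NaN inputs) rather than the number of steps. The sketch below is a pure-Python illustration of an EWMA that carries such a counter. It assumes pandas-like semantics with `adjust=True` and `ignore_na=False`, and assumes `min_periods=0` is treated as 1. `ewma_with_nobs` is a hypothetical helper, not the WAX-ML or numba implementation from this PR.

```python
import math


def ewma_with_nobs(values, com, min_periods=0):
    """Adjusted EWMA over a list of floats, returning (outputs, nobs)."""
    decay = 1.0 - 1.0 / (1.0 + com)  # 1 - alpha, with alpha = 1 / (1 + com)
    num, den, nobs = 0.0, 0.0, 0  # weighted sum, sum of weights, observation count
    outputs = []
    for value in values:
        # Every step ages the past, including steps where the input is NaN.
        num *= decay
        den *= decay
        if not math.isnan(value):
            num += value
            den += 1.0
            nobs += 1
        ready = nobs >= max(min_periods, 1)
        outputs.append(num / den if ready else math.nan)
    return outputs, nobs


outputs, nobs = ewma_with_nobs([math.nan, 1.0, math.nan, 3.0], com=2)
```

Here the leading NaN produces a NaN output, the NaN after the first observation carries the previous mean forward, and `nobs` ends at 2 even though four steps were processed.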